diff --git a/.mockery.yaml b/.mockery.yaml index fc3af65..b674007 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -28,24 +28,19 @@ packages: interfaces: KekRepository: {} KekUseCase: {} - DekRepository: {} - DekUseCase: {} github.com/allisson/secrets/internal/database: interfaces: TxManager: {} github.com/allisson/secrets/internal/secrets/usecase: interfaces: - DekRepository: {} SecretRepository: {} SecretUseCase: {} github.com/allisson/secrets/internal/transit/usecase: interfaces: - DekRepository: {} TransitKeyRepository: {} TransitKeyUseCase: {} github.com/allisson/secrets/internal/tokenization/usecase: interfaces: - DekRepository: {} TokenizationKeyRepository: {} TokenizationKeyUseCase: {} TokenRepository: {} diff --git a/CONTEXT.md b/CONTEXT.md new file mode 100644 index 0000000..125c0c3 --- /dev/null +++ b/CONTEXT.md @@ -0,0 +1,109 @@ +# CONTEXT + +Domain vocabulary for the `secrets` service. Use these terms exactly in code, +docs, commit messages, and conversations. Drift weakens the seams. + +## Cryptography + +### Envelope encryption +The pattern in which a user payload is encrypted under a **Data Encryption Key +(DEK)**, and the DEK itself is encrypted under a **Key Encryption Key (KEK)**, +which is in turn protected by a **Master Key** held by an external **KMS**. +See [ADR-0001](docs/adr/0001-envelope-encryption-model.md). + +### Master Key +A symmetric key held outside this service in a KMS (AWS, GCP, Azure, or +`localsecrets://` for development). Never stored in the database. Loaded at +boot via `KMSKeeper.Decrypt` and held in a `MasterKeyChain` for the process +lifetime. + +### KEK — Key Encryption Key +A symmetric key that exists only to encrypt DEKs. Persisted in the `keks` +table as ciphertext (wrapped by a Master Key). Rotates on demand via the +KEK rotation worker. The set of all loaded KEKs is the `KekChain`; the +newest is the *active* KEK. + +### DEK — Data Encryption Key +A symmetric key that encrypts exactly one piece of user data: +- in `secrets` and `tokenization`: one DEK per stored row (fresh each call); +- in `transit`: one DEK per Transit Key, reused across user requests. + +Persisted in the `deks` table as ciphertext (wrapped by a KEK). Identified +by a UUIDv7 (`DekID`). + +### AEAD +Authenticated Encryption with Associated Data. The service supports +`aes-256-gcm` and `chacha20-poly1305`. Algorithm is chosen at DEK creation +and recorded on the envelope; ciphertext from one algorithm cannot be +decrypted by another. + +### Rewrap +Re-encrypting an existing DEK under a newer KEK without changing the +underlying DEK key material. Used by the KEK rotation worker so old +ciphertexts remain decryptable without bulk re-encryption. + +## Modules + +### Keyring +**The single module that owns envelope encryption.** Exposes a small interface +to feature modules; hides KEK chain, DEK lifecycle, AEAD selection, and KMS +calls behind it. Call sites do not know KEK from DEK. + +- `Encrypt(ctx, plaintext) → Envelope` — fresh-DEK envelope encryption. + Used by `secrets` and `tokenization`. +- `Decrypt(ctx, envelope) → plaintext` — inverse of `Encrypt`. +- `AllocateDek(ctx, alg) → DekHandle` — persists a DEK and returns an + opaque handle. Used by `transit` once per Transit Key. +- `EncryptWith(ctx, handle, plaintext, aad) → (ciphertext, nonce)` — encrypt + under a previously-allocated DEK. Used by `transit` per request. +- `DecryptWith(ctx, handle, ciphertext, nonce, aad) → plaintext` — inverse. +- `Rewrap(ctx, dekID)` — rewrap a DEK under the active KEK. Used by the + rotation worker. + +### Envelope +The value returned by `Keyring.Encrypt`. Contains `DekID`, `Ciphertext`, +`Nonce`, and `Algorithm`. Features persist all four fields; nothing else +about the DEK or KEK is leaked to callers. + +### DekHandle +An opaque reference to a persistent DEK held by Keyring. Returned by +`AllocateDek`, accepted by `EncryptWith` / `DecryptWith`. Features store +only the handle's `DekID` and reload it on demand. Used to model the +`transit` flow where many user requests share one DEK. + +### Transit Key +A named, long-lived encryption key managed via the transit HTTP API. +Backed internally by a single DEK (a DekHandle). Users call `encrypt` and +`decrypt` against the name; the DEK never leaves Keyring. + +### Tokenization Key +A named encryption key associated with a token format (UUID, numeric, +alphanumeric, Luhn) and a determinism flag. Each tokenize call still uses +a fresh DEK via `Keyring.Encrypt` — the Tokenization Key itself is +metadata + format rules, not a long-lived crypto key. + +### Secret +A path-addressed, versioned encrypted payload. Each version has its own +DEK (fresh per write). The path is the lookup key; the latest version is +the default read. + +### KEK rotation worker +A background job that calls `Keyring.Rewrap` for every DEK not yet +encrypted under the active KEK. Runs after `Keyring.RotateKek`. Idempotent. + +## Storage + +### `keks` table +Wrapped KEK material, ordered by `version`. The highest-version, non-revoked +row is the active KEK. + +### `deks` table +Wrapped DEK material, joined to the `keks` row that wrapped them. Indexed +by `kek_id` to support the rotation worker's batch query. + +### Transactions +All multi-row writes (creating a DEK + the row that references it) happen +inside a `database.TxManager` transaction propagated via `context.Context` +(per [ADR-0005](docs/adr/0005-context-based-transaction-management.md)). +`Keyring.Encrypt` and `Keyring.AllocateDek` join the caller's transaction +when one is present. diff --git a/cmd/app/commands/rewrap_deks.go b/cmd/app/commands/rewrap_deks.go index 6eab5e8..fa57be2 100644 --- a/cmd/app/commands/rewrap_deks.go +++ b/cmd/app/commands/rewrap_deks.go @@ -7,23 +7,21 @@ import ( "github.com/google/uuid" - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" + "github.com/allisson/secrets/internal/keyring" ) -// RunRewrapDeks finds all DEKs that are not encrypted with the specified KEK ID -// and re-encrypts them with the specified KEK in batches. +// RunRewrapDeks finds DEKs not encrypted with the keyring's active KEK and +// rewraps them in batches. The kekIDStr argument is a safety check: it must +// match the keyring's currently-active KEK, so an operator cannot accidentally +// rewrap DEKs against a stale chain. func RunRewrapDeks( ctx context.Context, - masterKeyChain *cryptoDomain.MasterKeyChain, - kekUseCase cryptoUseCase.KekUseCase, - dekUseCase cryptoUseCase.DekUseCase, + kr keyring.Keyring, logger *slog.Logger, kekIDStr string, batchSize int, ) error { - // Parse KEK ID - newKekID, err := uuid.Parse(kekIDStr) + wantedKekID, err := uuid.Parse(kekIDStr) if err != nil { return fmt.Errorf("invalid kek-id: %w", err) } @@ -32,38 +30,27 @@ func RunRewrapDeks( return fmt.Errorf("batch-size must be greater than 0") } + activeKekID := kr.ActiveKekID() + if activeKekID != wantedKekID { + return fmt.Errorf( + "requested kek-id %s does not match keyring active KEK %s; "+ + "restart the rewrap process after KEK rotation so the latest chain is loaded", + wantedKekID, activeKekID, + ) + } + logger.Info("starting DEK rewrap process", slog.String("kek_id", kekIDStr), slog.Int("batch_size", batchSize), ) - kekChain, err := kekUseCase.Unwrap(ctx, masterKeyChain) + total, err := kr.RewrapAll(ctx, batchSize) if err != nil { - return fmt.Errorf("failed to load and unwrap kek chain: %w", err) - } - defer kekChain.Close() - - totalRewrapped := 0 - - for { - rewrappedCount, err := dekUseCase.Rewrap(ctx, kekChain, newKekID, batchSize) - if err != nil { - return fmt.Errorf("failed to rewrap DEKs in batch: %w", err) - } - - if rewrappedCount == 0 { - break - } - - totalRewrapped += rewrappedCount - logger.Info("rewrapped batch of DEKs", - slog.Int("rewrapped_in_batch", rewrappedCount), - slog.Int("total_rewrapped", totalRewrapped), - ) + return fmt.Errorf("failed to rewrap DEKs: %w", err) } logger.Info("DEK rewrap process completed", - slog.Int("total_rewrapped", totalRewrapped), + slog.Int("total_rewrapped", total), slog.String("target_kek_id", kekIDStr), ) diff --git a/cmd/app/commands/rewrap_deks_test.go b/cmd/app/commands/rewrap_deks_test.go index 0372920..5326b31 100644 --- a/cmd/app/commands/rewrap_deks_test.go +++ b/cmd/app/commands/rewrap_deks_test.go @@ -1,43 +1,64 @@ -package commands +package commands_test import ( "context" + "io" "log/slog" "testing" "github.com/google/uuid" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoMocks "github.com/allisson/secrets/internal/crypto/usecase/mocks" + "github.com/allisson/secrets/cmd/app/commands" + "github.com/allisson/secrets/internal/keyring" ) -func TestRunRewrapDeks(t *testing.T) { - ctx := context.Background() - logger := slog.Default() - masterKeyChain := cryptoDomain.NewMasterKeyChain("test-master-key") - kekID := uuid.New() - kekIDStr := kekID.String() - - t.Run("success", func(t *testing.T) { - mockKekUseCase := &cryptoMocks.MockKekUseCase{} - mockDekUseCase := &cryptoMocks.MockDekUseCase{} - kekChain := cryptoDomain.NewKekChain(nil) - - mockKekUseCase.On("Unwrap", ctx, masterKeyChain).Return(kekChain, nil) - mockDekUseCase.On("Rewrap", ctx, kekChain, kekID, 100).Return(10, nil).Once() - mockDekUseCase.On("Rewrap", ctx, kekChain, kekID, 100).Return(0, nil).Once() - - err := RunRewrapDeks(ctx, masterKeyChain, mockKekUseCase, mockDekUseCase, logger, kekIDStr, 100) - require.NoError(t, err) - - mockKekUseCase.AssertExpectations(t) - mockDekUseCase.AssertExpectations(t) - }) - - t.Run("invalid-kek-id", func(t *testing.T) { - err := RunRewrapDeks(ctx, masterKeyChain, nil, nil, logger, "invalid", 100) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid kek-id") - }) +func TestRunRewrapDeks_InvalidKekID(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + fake := keyring.NewFake() + + err := commands.RunRewrapDeks(context.Background(), fake, logger, "not-a-uuid", 100) + assert.ErrorContains(t, err, "invalid kek-id") +} + +func TestRunRewrapDeks_InvalidBatchSize(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + fake := keyring.NewFake() + + err := commands.RunRewrapDeks( + context.Background(), + fake, + logger, + uuid.Nil.String(), // Fake's ActiveKekID() returns Nil + 0, + ) + assert.ErrorContains(t, err, "batch-size must be greater than 0") +} + +func TestRunRewrapDeks_MismatchedActiveKek(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + fake := keyring.NewFake() + + err := commands.RunRewrapDeks( + context.Background(), + fake, + logger, + uuid.New().String(), // doesn't match Fake's Nil active id + 100, + ) + assert.ErrorContains(t, err, "does not match keyring active KEK") +} + +func TestRunRewrapDeks_SuccessNoDEKs(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + fake := keyring.NewFake() + + err := commands.RunRewrapDeks( + context.Background(), + fake, + logger, + uuid.Nil.String(), + 100, + ) + assert.NoError(t, err) } diff --git a/cmd/app/key_commands.go b/cmd/app/key_commands.go index d755d8b..1bd96cf 100644 --- a/cmd/app/key_commands.go +++ b/cmd/app/key_commands.go @@ -190,26 +190,14 @@ func getKeyCommands() []*cli.Command { return commands.ExecuteWithContainer( ctx, func(ctx context.Context, container *app.Container) error { - masterKeyChain, err := container.MasterKeyChain(ctx) - if err != nil { - return err - } - - kekUseCase, err := container.KekUseCase(ctx) - if err != nil { - return err - } - - dekUseCase, err := container.CryptoDekUseCase(ctx) + kr, err := container.Keyring(ctx) if err != nil { return err } return commands.RunRewrapDeks( ctx, - masterKeyChain, - kekUseCase, - dekUseCase, + kr, container.Logger(), cmd.String("kek-id"), int(cmd.Int("batch-size")), diff --git a/docs/adr/0001-envelope-encryption-model.md b/docs/adr/0001-envelope-encryption-model.md index 8742af0..2e71fba 100644 --- a/docs/adr/0001-envelope-encryption-model.md +++ b/docs/adr/0001-envelope-encryption-model.md @@ -23,9 +23,19 @@ Use envelope encryption hierarchy: - historical versions remain decryptable with prior key material - clear separation between root trust, key-wrapping, and data encryption roles +## Module structure + +The envelope-encryption model is implemented by a single module, +`internal/keyring`, introduced in [ADR 0013](0013-keyring-as-envelope-encryption-module.md). +Features (secrets, transit, tokenization) call `Keyring.Encrypt` / +`Keyring.Decrypt` (or the persistent-DEK pair `AllocateDek` / +`EncryptWith` / `DecryptWith`) and do not handle the KEK chain, DEK +lifecycle, or AEAD selection directly. + ## See also - [Architecture](../concepts/architecture.md) - [Security model](../concepts/security-model.md) - [Key management operations](../operations/kms/key-management.md) - [ADR 0012: PostgreSQL-Only Database](0012-postgresql-only-database.md) - Database storage for encrypted key material +- [ADR 0013: Keyring as Envelope Encryption Module](0013-keyring-as-envelope-encryption-module.md) - Single-module implementation of this model diff --git a/docs/adr/0013-keyring-as-envelope-encryption-module.md b/docs/adr/0013-keyring-as-envelope-encryption-module.md new file mode 100644 index 0000000..2a9ecd9 --- /dev/null +++ b/docs/adr/0013-keyring-as-envelope-encryption-module.md @@ -0,0 +1,133 @@ +# ADR 0013: Keyring as Envelope Encryption Module + +> Status: accepted +> Date: 2026-05-23 + +## Context + +[ADR 0001](0001-envelope-encryption-model.md) defines the cryptographic +hierarchy (Master Key → KEK → DEK → data) but does not say where the +implementation lives. Initially the model was scattered across: + +- `internal/crypto/domain` — types for KEK, DEK, MasterKeyChain +- `internal/crypto/service` — AEAD, AEADManager, KeyManager +- `internal/crypto/usecase` — KekUseCase, DekUseCase +- `internal/crypto/repository` — KEK and DEK persistence + +Each feature module (`secrets`, `transit`, `tokenization`) imported +four of these and reimplemented the same six-step envelope dance per +operation: + +```text +get active KEK → CreateDek under KEK → persist DEK → DecryptDek → +CreateCipher with DEK key → Encrypt plaintext under cipher +``` + +This had three concrete costs: + +1. **The interface is the pattern.** ADR-0001 was enforced by + convention, not by a module boundary. Three near-identical copies of + the dance existed (secrets, transit, tokenization), each with its + own subtle differences. +2. **Single-adapter interfaces.** Every layer of the dance exposed + interfaces (DekRepository, AEADManager, KeyManager) with exactly one + production adapter and one generated mock. The mocks were the + second "adapter," but they only existed for test isolation — they + were not a real seam. +3. **Constructor bloat.** Feature usecases took 6–8 dependencies, most + of them crypto plumbing rather than feature business logic. + +## Decision + +Introduce a single deep module, `internal/keyring`, that owns envelope +encryption end-to-end. Features depend only on the `Keyring` +interface and the small `Envelope` / `DekHandle` value types. + +The interface exposes two encryption shapes: + +```go +// Fresh-DEK envelope: used by secrets and tokenization where each +// stored item gets its own DEK. +Encrypt(ctx, plaintext) → Envelope +Decrypt(ctx, env) → plaintext + +// Persistent DEK: used by transit and tokenization-keys where a single +// DEK wraps many plaintexts over its lifetime. +AllocateDek(ctx, alg) → DekHandle +EncryptWith(ctx, handle, plaintext, aad) → (ciphertext, nonce) +DecryptWith(ctx, handle, ciphertext, nonce, aad) → plaintext + +// KEK rotation: +Rewrap(ctx, dekID) // single DEK +RewrapAll(ctx, batchSize) // batch worker +ActiveKekID() // safety check for the rotation CLI +``` + +`Envelope` is `{DekID, Ciphertext, Nonce}`. Features persist these +three fields and nothing else about crypto state. `DekHandle` is an +opaque `{DekID}` returned by `AllocateDek` and consumed by the +EncryptWith/DecryptWith pair. + +Keyring is constructed once at boot from a loaded `KekChain`, the +shared `KeyManager` and `AEADManager`, and a concrete +`*cryptoRepository.DekRepository`. The KEK chain, AEAD selection, DEK +persistence, and key zeroing all live inside the module. + +A `keyring.Fake` is shipped alongside the production implementation. It +is deterministic, in-memory, and gives features a real second adapter +— making the seam a genuine seam (per the "one adapter = hypothetical +seam, two adapters = real seam" principle) rather than only mock +scaffolding. + +`internal/crypto/usecase/DekUseCase` and the per-feature +`DekRepository` interfaces are deleted; their behaviour moves into the +Keyring. `internal/crypto/usecase/KekUseCase` remains for the KEK +lifecycle CLI (Create, Rotate, Unwrap), which runs outside the +request-time path Keyring serves. + +## Consequences + +### Positive + +- ADR-0001's envelope-encryption hierarchy is enforced by a module + boundary, not by per-feature convention. +- Feature usecases shed 4–5 of their crypto dependencies. Constructor + signatures shrink dramatically (secrets: 8→4, tokenization key: 5→3, + transit: 6→3). +- `keyring.Fake` is the first real second adapter in the codebase. + Feature unit tests now exercise behaviour ("a secret round-trips") + rather than asserting the call sequence against five mocks. The + full feature test suite shrank by ~5,000 lines across secrets, + tokenization, and transit. +- Adding a new feature that needs envelope encryption is now trivial: + inject Keyring, call `Encrypt`/`Decrypt`. No knowledge of KEK, DEK, + AEAD, or KMS required at the call site. +- Memory zeroing of DEK plaintext material is centralised in one + place. Features cannot leak a DEK key by forgetting to call + `Zero()`. + +### Negative + +- Keyring takes a concrete `*cryptoRepository.DekRepository` rather + than the narrower `internal/crypto/usecase.DekRepository` interface. + This is fine — Keyring needs `Create + Get + Update + + GetBatchNotKekID` and only the concrete struct provides all four + today. +- The boot-time `KekChain` is loaded once and cached on the running + Keyring. The rewrap CLI must therefore run in a freshly-booted + process after KEK rotation; this is enforced by + `Keyring.ActiveKekID()` matching the operator-provided + `--kek-id` argument. +- The `nonceSize` used in transit's wire format (12 bytes) is + hardcoded in `internal/transit/usecase`. Both currently-supported + algorithms (AES-256-GCM, ChaCha20-Poly1305) use 12-byte nonces; a + future algorithm with a different nonce size would need either a + size accessor on Keyring or a small refactor of the transit blob + format. + +## See also + +- [ADR 0001: Envelope Encryption Model](0001-envelope-encryption-model.md) +- [ADR 0002: Transit Versioned Ciphertext Contract](0002-transit-versioned-ciphertext-contract.md) - Transit wire format that sits on top of Keyring's EncryptWith/DecryptWith +- [ADR 0005: Context-Based Transaction Management](0005-context-based-transaction-management.md) - Keyring's Encrypt/AllocateDek join the caller's transaction via ctx +- [`CONTEXT.md`](../../CONTEXT.md) - Keyring, Envelope, DekHandle, Rewrap vocabulary diff --git a/internal/app/di.go b/internal/app/di.go index a13b9b4..35c9974 100644 --- a/internal/app/di.go +++ b/internal/app/di.go @@ -18,6 +18,7 @@ import ( cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" "github.com/allisson/secrets/internal/database" "github.com/allisson/secrets/internal/http" + "github.com/allisson/secrets/internal/keyring" "github.com/allisson/secrets/internal/metrics" secretsHTTP "github.com/allisson/secrets/internal/secrets/http" secretsUseCase "github.com/allisson/secrets/internal/secrets/usecase" @@ -51,23 +52,21 @@ type Container struct { secretService authService.SecretService tokenService authService.TokenService + // Keyring (envelope encryption) + keyring keyring.Keyring + // Repositories kekRepository cryptoUseCase.KekRepository - cryptoDekRepository cryptoUseCase.DekRepository - dekRepository secretsUseCase.DekRepository secretRepository secretsUseCase.SecretRepository clientRepository authUseCase.ClientRepository tokenRepository authUseCase.TokenRepository auditLogRepository authUseCase.AuditLogRepository transitKeyRepository transitUseCase.TransitKeyRepository - transitDekRepository transitUseCase.DekRepository tokenizationKeyRepository tokenizationUseCase.TokenizationKeyRepository tokenizationTokenRepository tokenizationUseCase.TokenRepository - tokenizationDekRepository tokenizationUseCase.DekRepository // Use Cases kekUseCase cryptoUseCase.KekUseCase - cryptoDekUseCase cryptoUseCase.DekUseCase secretUseCase secretsUseCase.SecretUseCase clientUseCase authUseCase.ClientUseCase tokenUseCase authUseCase.TokenUseCase @@ -102,20 +101,16 @@ type Container struct { kmsServiceInit sync.Once secretServiceInit sync.Once tokenServiceInit sync.Once + keyringInit sync.Once kekRepositoryInit sync.Once - cryptoDekRepositoryInit sync.Once - dekRepositoryInit sync.Once secretRepositoryInit sync.Once clientRepositoryInit sync.Once tokenRepositoryInit sync.Once auditLogRepositoryInit sync.Once transitKeyRepositoryInit sync.Once - transitDekRepositoryInit sync.Once tokenizationKeyRepositoryInit sync.Once tokenizationTokenRepositoryInit sync.Once - tokenizationDekRepositoryInit sync.Once kekUseCaseInit sync.Once - cryptoDekUseCaseInit sync.Once secretUseCaseInit sync.Once clientUseCaseInit sync.Once tokenUseCaseInit sync.Once diff --git a/internal/app/di_crypto.go b/internal/app/di_crypto.go index f2d8900..9268ed0 100644 --- a/internal/app/di_crypto.go +++ b/internal/app/di_crypto.go @@ -8,8 +8,50 @@ import ( cryptoRepository "github.com/allisson/secrets/internal/crypto/repository" cryptoService "github.com/allisson/secrets/internal/crypto/service" cryptoUseCase "github.com/allisson/secrets/internal/crypto/usecase" + "github.com/allisson/secrets/internal/keyring" ) +// Keyring returns the envelope-encryption keyring shared by all features. +func (c *Container) Keyring(ctx context.Context) (keyring.Keyring, error) { + var err error + c.keyringInit.Do(func() { + c.keyring, err = c.initKeyring(ctx) + if err != nil { + c.initErrors.Store("keyring", err) + } + }) + if err != nil { + return nil, err + } + if val, ok := c.initErrors.Load("keyring"); ok { + return nil, val.(error) + } + return c.keyring, nil +} + +func (c *Container) initKeyring(ctx context.Context) (keyring.Keyring, error) { + db, err := c.DB(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get database for keyring: %w", err) + } + + kekChain, err := c.loadKekChain(ctx) + if err != nil { + return nil, fmt.Errorf("failed to load kek chain for keyring: %w", err) + } + + // Keyring needs Create/Get/Update on DEKs. The concrete repository + // provides all three; cryptoUseCase.DekRepository (Update + batch lookup) + // is the narrower interface used by KEK rotation only. + return keyring.New( + kekChain, + cryptoRepository.NewDekRepository(db), + c.AEADManager(), + c.KeyManager(), + cryptoDomain.AESGCM, + ), nil +} + // MasterKeyChain returns the master key chain loaded from environment variables. func (c *Container) MasterKeyChain(ctx context.Context) (*cryptoDomain.MasterKeyChain, error) { var err error @@ -88,42 +130,6 @@ func (c *Container) KekUseCase(ctx context.Context) (cryptoUseCase.KekUseCase, e return c.kekUseCase, nil } -// CryptoDekRepository returns the DEK repository for the crypto use case. -func (c *Container) CryptoDekRepository(ctx context.Context) (cryptoUseCase.DekRepository, error) { - var err error - c.cryptoDekRepositoryInit.Do(func() { - c.cryptoDekRepository, err = c.initCryptoDekRepository(ctx) - if err != nil { - c.initErrors.Store("cryptoDekRepository", err) - } - }) - if err != nil { - return nil, err - } - if val, ok := c.initErrors.Load("cryptoDekRepository"); ok { - return nil, val.(error) - } - return c.cryptoDekRepository, nil -} - -// CryptoDekUseCase returns the DEK use case for the crypto module. -func (c *Container) CryptoDekUseCase(ctx context.Context) (cryptoUseCase.DekUseCase, error) { - var err error - c.cryptoDekUseCaseInit.Do(func() { - c.cryptoDekUseCase, err = c.initCryptoDekUseCase(ctx) - if err != nil { - c.initErrors.Store("cryptoDekUseCase", err) - } - }) - if err != nil { - return nil, err - } - if val, ok := c.initErrors.Load("cryptoDekUseCase"); ok { - return nil, val.(error) - } - return c.cryptoDekUseCase, nil -} - // initMasterKeyChain loads the master key chain from environment variables. func (c *Container) initMasterKeyChain(ctx context.Context) (*cryptoDomain.MasterKeyChain, error) { // Get KMS service and logger @@ -186,33 +192,6 @@ func (c *Container) initKekUseCase(ctx context.Context) (cryptoUseCase.KekUseCas return cryptoUseCase.NewKekUseCase(txManager, kekRepository, keyManager), nil } -// initCryptoDekRepository creates the DEK repository for crypto use case. -func (c *Container) initCryptoDekRepository(ctx context.Context) (cryptoUseCase.DekRepository, error) { - db, err := c.DB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get database: %w", err) - } - - return cryptoRepository.NewDekRepository(db), nil -} - -// initCryptoDekUseCase creates the DEK use case for the crypto module. -func (c *Container) initCryptoDekUseCase(ctx context.Context) (cryptoUseCase.DekUseCase, error) { - txManager, err := c.TxManager(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get tx manager: %w", err) - } - - dekRepo, err := c.CryptoDekRepository(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get crypto dek repository: %w", err) - } - - keyManager := c.KeyManager() - - return cryptoUseCase.NewDekUseCase(txManager, dekRepo, keyManager), nil -} - // loadKekChain loads all KEKs from the database and creates a KEK chain. func (c *Container) loadKekChain(ctx context.Context) (*cryptoDomain.KekChain, error) { kekUseCase, err := c.KekUseCase(ctx) diff --git a/internal/app/di_secrets.go b/internal/app/di_secrets.go index dea3fcb..0607a32 100644 --- a/internal/app/di_secrets.go +++ b/internal/app/di_secrets.go @@ -4,31 +4,11 @@ import ( "context" "fmt" - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoRepository "github.com/allisson/secrets/internal/crypto/repository" secretsHTTP "github.com/allisson/secrets/internal/secrets/http" secretsRepository "github.com/allisson/secrets/internal/secrets/repository" secretsUseCase "github.com/allisson/secrets/internal/secrets/usecase" ) -// DekRepository returns the DEK repository. -func (c *Container) DekRepository(ctx context.Context) (secretsUseCase.DekRepository, error) { - var err error - c.dekRepositoryInit.Do(func() { - c.dekRepository, err = c.initDekRepository(ctx) - if err != nil { - c.initErrors.Store("dekRepository", err) - } - }) - if err != nil { - return nil, err - } - if val, ok := c.initErrors.Load("dekRepository"); ok { - return nil, val.(error) - } - return c.dekRepository, nil -} - // SecretRepository returns the secret repository. func (c *Container) SecretRepository(ctx context.Context) (secretsUseCase.SecretRepository, error) { var err error @@ -83,17 +63,6 @@ func (c *Container) SecretHandler(ctx context.Context) (*secretsHTTP.SecretHandl return c.secretHandler, nil } -// initDekRepository creates the DEK repository. -func (c *Container) initDekRepository(ctx context.Context) (secretsUseCase.DekRepository, error) { - db, err := c.DB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get database for dek repository: %w", err) - } - - return cryptoRepository.NewDekRepository(db), nil -} - -// initSecretRepository creates the secret repository. func (c *Container) initSecretRepository(ctx context.Context) (secretsUseCase.SecretRepository, error) { db, err := c.DB(ctx) if err != nil { @@ -103,16 +72,15 @@ func (c *Container) initSecretRepository(ctx context.Context) (secretsUseCase.Se return secretsRepository.NewSecretRepository(db), nil } -// initSecretUseCase creates the secret use case with all its dependencies. func (c *Container) initSecretUseCase(ctx context.Context) (secretsUseCase.SecretUseCase, error) { txManager, err := c.TxManager(ctx) if err != nil { return nil, fmt.Errorf("failed to get tx manager for secret use case: %w", err) } - dekRepository, err := c.DekRepository(ctx) + kr, err := c.Keyring(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dek repository for secret use case: %w", err) + return nil, fmt.Errorf("failed to get keyring for secret use case: %w", err) } secretRepository, err := c.SecretRepository(ctx) @@ -120,26 +88,13 @@ func (c *Container) initSecretUseCase(ctx context.Context) (secretsUseCase.Secre return nil, fmt.Errorf("failed to get secret repository for secret use case: %w", err) } - kekChain, err := c.loadKekChain(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for secret use case: %w", err) - } - - aeadManager := c.AEADManager() - keyManager := c.KeyManager() - baseUseCase := secretsUseCase.NewSecretUseCase( txManager, - dekRepository, + kr, secretRepository, - kekChain, - aeadManager, - keyManager, - cryptoDomain.AESGCM, c.config.SecretValueSizeLimitBytes, ) - // Wrap with metrics if enabled if c.config.MetricsEnabled { businessMetrics, err := c.BusinessMetrics(ctx) if err != nil { @@ -151,7 +106,6 @@ func (c *Container) initSecretUseCase(ctx context.Context) (secretsUseCase.Secre return baseUseCase, nil } -// initSecretHandler creates the secret HTTP handler with all its dependencies. func (c *Container) initSecretHandler(ctx context.Context) (*secretsHTTP.SecretHandler, error) { secretUseCase, err := c.SecretUseCase(ctx) if err != nil { @@ -163,7 +117,5 @@ func (c *Container) initSecretHandler(ctx context.Context) (*secretsHTTP.SecretH return nil, fmt.Errorf("failed to get audit log use case for secret handler: %w", err) } - logger := c.Logger() - - return secretsHTTP.NewSecretHandler(secretUseCase, auditLogUseCase, logger), nil + return secretsHTTP.NewSecretHandler(secretUseCase, auditLogUseCase, c.Logger()), nil } diff --git a/internal/app/di_test.go b/internal/app/di_test.go index 2116ff4..5708d2c 100644 --- a/internal/app/di_test.go +++ b/internal/app/di_test.go @@ -348,50 +348,6 @@ func TestContainerKekUseCaseErrors(t *testing.T) { } } -// TestContainerCryptoDekRepositoryErrors verifies that Crypto DEK repository initialization errors are properly handled. -func TestContainerCryptoDekRepositoryErrors(t *testing.T) { - // Create a container with invalid database configuration - cfg := &config.Config{ - DBConnectionString: "", - } - - container := NewContainer(cfg) - - // Attempting to get Crypto DEK repository should return an error - _, err := container.CryptoDekRepository(context.Background()) - if err == nil { - t.Error("expected error when connecting with invalid config") - } - - // Attempting to get Crypto DEK repository again should return the same error - _, err2 := container.CryptoDekRepository(context.Background()) - if err2 == nil { - t.Error("expected error on second call to CryptoDekRepository()") - } -} - -// TestContainerCryptoDekUseCaseErrors verifies that Crypto DEK use case initialization errors are properly handled. -func TestContainerCryptoDekUseCaseErrors(t *testing.T) { - // Create a container with invalid database configuration - cfg := &config.Config{ - DBConnectionString: "", - } - - container := NewContainer(cfg) - - // Attempting to get Crypto DEK use case should return an error (due to DB error) - _, err := container.CryptoDekUseCase(context.Background()) - if err == nil { - t.Error("expected error when connecting with invalid config") - } - - // Attempting to get Crypto DEK use case again should return the same error - _, err2 := container.CryptoDekUseCase(context.Background()) - if err2 == nil { - t.Error("expected error on second call to CryptoDekUseCase()") - } -} - // TestContainerMasterKeyChain verifies that the master key chain can be retrieved from the container. func TestContainerMasterKeyChain(t *testing.T) { ctx := context.Background() @@ -748,9 +704,9 @@ func TestContainerSecretsComponents(t *testing.T) { // Since repositories need a DB, we expect errors if DB is not and cannot be connected - _, err := container.DekRepository(ctx) + _, err := container.Keyring(ctx) if err == nil { - t.Error("expected error for dek repository with invalid db config") + t.Error("expected error for keyring with invalid db config") } _, err = container.SecretRepository(ctx) @@ -783,11 +739,6 @@ func TestContainerTransitComponents(t *testing.T) { t.Error("expected error for transit key repository with invalid db config") } - _, err = container.TransitDekRepository(ctx) - if err == nil { - t.Error("expected error for transit dek repository with invalid db config") - } - _, err = container.TransitKeyUseCase(ctx) if err == nil { t.Error("expected error for transit key use case with invalid db config") @@ -823,11 +774,6 @@ func TestContainerTokenizationComponents(t *testing.T) { t.Error("expected error for tokenization token repository with invalid db config") } - _, err = container.TokenizationDekRepository(ctx) - if err == nil { - t.Error("expected error for tokenization dek repository with invalid db config") - } - _, err = container.TokenizationKeyUseCase(ctx) if err == nil { t.Error("expected error for tokenization key use case with invalid db config") diff --git a/internal/app/di_tokenization.go b/internal/app/di_tokenization.go index 1b8bac0..dfaba31 100644 --- a/internal/app/di_tokenization.go +++ b/internal/app/di_tokenization.go @@ -4,13 +4,11 @@ import ( "context" "fmt" - cryptoRepository "github.com/allisson/secrets/internal/crypto/repository" tokenizationHTTP "github.com/allisson/secrets/internal/tokenization/http" tokenizationRepository "github.com/allisson/secrets/internal/tokenization/repository" tokenizationUseCase "github.com/allisson/secrets/internal/tokenization/usecase" ) -// TokenizationKeyRepository returns the tokenization key repository. func (c *Container) TokenizationKeyRepository( ctx context.Context, ) (tokenizationUseCase.TokenizationKeyRepository, error) { @@ -30,7 +28,6 @@ func (c *Container) TokenizationKeyRepository( return c.tokenizationKeyRepository, nil } -// TokenizationTokenRepository returns the tokenization token repository. func (c *Container) TokenizationTokenRepository( ctx context.Context, ) (tokenizationUseCase.TokenRepository, error) { @@ -50,27 +47,6 @@ func (c *Container) TokenizationTokenRepository( return c.tokenizationTokenRepository, nil } -// TokenizationDekRepository returns the DEK repository for tokenization use case. -func (c *Container) TokenizationDekRepository( - ctx context.Context, -) (tokenizationUseCase.DekRepository, error) { - var err error - c.tokenizationDekRepositoryInit.Do(func() { - c.tokenizationDekRepository, err = c.initTokenizationDekRepository(ctx) - if err != nil { - c.initErrors.Store("tokenizationDekRepository", err) - } - }) - if err != nil { - return nil, err - } - if val, ok := c.initErrors.Load("tokenizationDekRepository"); ok { - return nil, val.(error) - } - return c.tokenizationDekRepository, nil -} - -// TokenizationKeyUseCase returns the tokenization key use case. func (c *Container) TokenizationKeyUseCase( ctx context.Context, ) (tokenizationUseCase.TokenizationKeyUseCase, error) { @@ -90,7 +66,6 @@ func (c *Container) TokenizationKeyUseCase( return c.tokenizationKeyUseCase, nil } -// TokenizationUseCase returns the tokenization use case. func (c *Container) TokenizationUseCase( ctx context.Context, ) (tokenizationUseCase.TokenizationUseCase, error) { @@ -110,7 +85,6 @@ func (c *Container) TokenizationUseCase( return c.tokenizationUseCase, nil } -// TokenizationKeyHandler returns the tokenization key HTTP handler. func (c *Container) TokenizationKeyHandler( ctx context.Context, ) (*tokenizationHTTP.TokenizationKeyHandler, error) { @@ -130,7 +104,6 @@ func (c *Container) TokenizationKeyHandler( return c.tokenizationKeyHandler, nil } -// TokenizationHandler returns the tokenization HTTP handler. func (c *Container) TokenizationHandler(ctx context.Context) (*tokenizationHTTP.TokenizationHandler, error) { var err error c.tokenizationHandlerInit.Do(func() { @@ -148,7 +121,6 @@ func (c *Container) TokenizationHandler(ctx context.Context) (*tokenizationHTTP. return c.tokenizationHandler, nil } -// initTokenizationKeyRepository creates the tokenization key repository. func (c *Container) initTokenizationKeyRepository( ctx context.Context, ) (tokenizationUseCase.TokenizationKeyRepository, error) { @@ -160,7 +132,6 @@ func (c *Container) initTokenizationKeyRepository( return tokenizationRepository.NewTokenizationKeyRepository(db), nil } -// initTokenizationTokenRepository creates the tokenization token repository. func (c *Container) initTokenizationTokenRepository( ctx context.Context, ) (tokenizationUseCase.TokenRepository, error) { @@ -172,19 +143,6 @@ func (c *Container) initTokenizationTokenRepository( return tokenizationRepository.NewTokenRepository(db), nil } -// initTokenizationDekRepository creates the DEK repository for tokenization use case. -func (c *Container) initTokenizationDekRepository( - ctx context.Context, -) (tokenizationUseCase.DekRepository, error) { - db, err := c.DB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get database for tokenization dek repository: %w", err) - } - - return cryptoRepository.NewDekRepository(db), nil -} - -// initTokenizationKeyUseCase creates the tokenization key use case. func (c *Container) initTokenizationKeyUseCase( ctx context.Context, ) (tokenizationUseCase.TokenizationKeyUseCase, error) { @@ -201,28 +159,18 @@ func (c *Container) initTokenizationKeyUseCase( ) } - dekRepository, err := c.TokenizationDekRepository(ctx) + kr, err := c.Keyring(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dek repository for tokenization key use case: %w", err) + return nil, fmt.Errorf("failed to get keyring for tokenization key use case: %w", err) } - kekChain, err := c.loadKekChain(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for tokenization key use case: %w", err) - } - - keyManager := c.KeyManager() - return tokenizationUseCase.NewTokenizationKeyUseCase( txManager, tokenizationKeyRepository, - dekRepository, - keyManager, - kekChain, + kr, ), nil } -// initTokenizationUseCase creates the tokenization use case. func (c *Container) initTokenizationUseCase( ctx context.Context, ) (tokenizationUseCase.TokenizationUseCase, error) { @@ -244,34 +192,21 @@ func (c *Container) initTokenizationUseCase( return nil, fmt.Errorf("failed to get token repository for tokenization use case: %w", err) } - dekRepository, err := c.TokenizationDekRepository(ctx) + kr, err := c.Keyring(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dek repository for tokenization use case: %w", err) + return nil, fmt.Errorf("failed to get keyring for tokenization use case: %w", err) } - aeadManager := c.AEADManager() - - keyManager := c.KeyManager() - hashService := tokenizationUseCase.NewSHA256HashService() - kekChain, err := c.loadKekChain(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for tokenization use case: %w", err) - } - baseUseCase := tokenizationUseCase.NewTokenizationUseCase( txManager, tokenizationKeyRepository, tokenRepository, - dekRepository, - aeadManager, - keyManager, hashService, - kekChain, + kr, ) - // Wrap with metrics if enabled if c.config.MetricsEnabled { businessMetrics, err := c.BusinessMetrics(ctx) if err != nil { @@ -283,7 +218,6 @@ func (c *Container) initTokenizationUseCase( return baseUseCase, nil } -// initTokenizationKeyHandler creates the tokenization key HTTP handler. func (c *Container) initTokenizationKeyHandler( ctx context.Context, ) (*tokenizationHTTP.TokenizationKeyHandler, error) { @@ -295,25 +229,20 @@ func (c *Container) initTokenizationKeyHandler( ) } - logger := c.Logger() - - return tokenizationHTTP.NewTokenizationKeyHandler(tokenizationKeyUseCase, logger), nil + return tokenizationHTTP.NewTokenizationKeyHandler(tokenizationKeyUseCase, c.Logger()), nil } -// initTokenizationHandler creates the tokenization HTTP handler. func (c *Container) initTokenizationHandler( ctx context.Context, ) (*tokenizationHTTP.TokenizationHandler, error) { - tokenizationUseCase, err := c.TokenizationUseCase(ctx) + tokenizationUC, err := c.TokenizationUseCase(ctx) if err != nil { return nil, fmt.Errorf("failed to get tokenization use case for tokenization handler: %w", err) } - logger := c.Logger() - return tokenizationHTTP.NewTokenizationHandler( - tokenizationUseCase, + tokenizationUC, c.config.TokenizationBatchLimit, - logger, + c.Logger(), ), nil } diff --git a/internal/app/di_transit.go b/internal/app/di_transit.go index b050916..1fe038a 100644 --- a/internal/app/di_transit.go +++ b/internal/app/di_transit.go @@ -4,13 +4,11 @@ import ( "context" "fmt" - cryptoRepository "github.com/allisson/secrets/internal/crypto/repository" transitHTTP "github.com/allisson/secrets/internal/transit/http" transitRepository "github.com/allisson/secrets/internal/transit/repository" transitUseCase "github.com/allisson/secrets/internal/transit/usecase" ) -// TransitKeyRepository returns the transit key repository instance. func (c *Container) TransitKeyRepository(ctx context.Context) (transitUseCase.TransitKeyRepository, error) { var err error c.transitKeyRepositoryInit.Do(func() { @@ -28,25 +26,6 @@ func (c *Container) TransitKeyRepository(ctx context.Context) (transitUseCase.Tr return c.transitKeyRepository, nil } -// TransitDekRepository returns the DEK repository for transit use case. -func (c *Container) TransitDekRepository(ctx context.Context) (transitUseCase.DekRepository, error) { - var err error - c.transitDekRepositoryInit.Do(func() { - c.transitDekRepository, err = c.initTransitDekRepository(ctx) - if err != nil { - c.initErrors.Store("transitDekRepository", err) - } - }) - if err != nil { - return nil, err - } - if val, ok := c.initErrors.Load("transitDekRepository"); ok { - return nil, val.(error) - } - return c.transitDekRepository, nil -} - -// TransitKeyUseCase returns the transit key use case instance. func (c *Container) TransitKeyUseCase(ctx context.Context) (transitUseCase.TransitKeyUseCase, error) { var err error c.transitKeyUseCaseInit.Do(func() { @@ -64,7 +43,6 @@ func (c *Container) TransitKeyUseCase(ctx context.Context) (transitUseCase.Trans return c.transitKeyUseCase, nil } -// TransitKeyHandler returns the transit key HTTP handler instance. func (c *Container) TransitKeyHandler(ctx context.Context) (*transitHTTP.TransitKeyHandler, error) { var err error c.transitKeyHandlerInit.Do(func() { @@ -82,7 +60,6 @@ func (c *Container) TransitKeyHandler(ctx context.Context) (*transitHTTP.Transit return c.transitKeyHandler, nil } -// CryptoHandler returns the crypto HTTP handler instance. func (c *Container) CryptoHandler(ctx context.Context) (*transitHTTP.CryptoHandler, error) { var err error c.cryptoHandlerInit.Do(func() { @@ -100,7 +77,6 @@ func (c *Container) CryptoHandler(ctx context.Context) (*transitHTTP.CryptoHandl return c.cryptoHandler, nil } -// initTransitKeyRepository creates the transit key repository. func (c *Container) initTransitKeyRepository( ctx context.Context, ) (transitUseCase.TransitKeyRepository, error) { @@ -112,17 +88,6 @@ func (c *Container) initTransitKeyRepository( return transitRepository.NewTransitKeyRepository(db), nil } -// initTransitDekRepository creates the DEK repository for transit use case. -func (c *Container) initTransitDekRepository(ctx context.Context) (transitUseCase.DekRepository, error) { - db, err := c.DB(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get database for transit dek repository: %w", err) - } - - return cryptoRepository.NewDekRepository(db), nil -} - -// initTransitKeyUseCase creates the transit key use case with all its dependencies. func (c *Container) initTransitKeyUseCase(ctx context.Context) (transitUseCase.TransitKeyUseCase, error) { txManager, err := c.TxManager(ctx) if err != nil { @@ -134,29 +99,13 @@ func (c *Container) initTransitKeyUseCase(ctx context.Context) (transitUseCase.T return nil, fmt.Errorf("failed to get transit key repository for transit key use case: %w", err) } - dekRepository, err := c.TransitDekRepository(ctx) + kr, err := c.Keyring(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dek repository for transit key use case: %w", err) + return nil, fmt.Errorf("failed to get keyring for transit key use case: %w", err) } - kekChain, err := c.loadKekChain(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load kek chain for transit key use case: %w", err) - } - - keyManager := c.KeyManager() - aeadManager := c.AEADManager() - - baseUseCase := transitUseCase.NewTransitKeyUseCase( - txManager, - transitKeyRepository, - dekRepository, - keyManager, - aeadManager, - kekChain, - ) + baseUseCase := transitUseCase.NewTransitKeyUseCase(txManager, transitKeyRepository, kr) - // Wrap with metrics if enabled if c.config.MetricsEnabled { businessMetrics, err := c.BusinessMetrics(ctx) if err != nil { @@ -168,26 +117,20 @@ func (c *Container) initTransitKeyUseCase(ctx context.Context) (transitUseCase.T return baseUseCase, nil } -// initTransitKeyHandler creates the transit key HTTP handler with all its dependencies. func (c *Container) initTransitKeyHandler(ctx context.Context) (*transitHTTP.TransitKeyHandler, error) { transitKeyUseCase, err := c.TransitKeyUseCase(ctx) if err != nil { return nil, fmt.Errorf("failed to get transit key use case for transit key handler: %w", err) } - logger := c.Logger() - - return transitHTTP.NewTransitKeyHandler(transitKeyUseCase, logger), nil + return transitHTTP.NewTransitKeyHandler(transitKeyUseCase, c.Logger()), nil } -// initCryptoHandler creates the crypto HTTP handler with all its dependencies. func (c *Container) initCryptoHandler(ctx context.Context) (*transitHTTP.CryptoHandler, error) { transitKeyUseCase, err := c.TransitKeyUseCase(ctx) if err != nil { return nil, fmt.Errorf("failed to get transit key use case for crypto handler: %w", err) } - logger := c.Logger() - - return transitHTTP.NewCryptoHandler(transitKeyUseCase, logger), nil + return transitHTTP.NewCryptoHandler(transitKeyUseCase, c.Logger()), nil } diff --git a/internal/crypto/usecase/dek_usecase.go b/internal/crypto/usecase/dek_usecase.go deleted file mode 100644 index 925a29c..0000000 --- a/internal/crypto/usecase/dek_usecase.go +++ /dev/null @@ -1,102 +0,0 @@ -package usecase - -import ( - "context" - - "github.com/google/uuid" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoService "github.com/allisson/secrets/internal/crypto/service" - "github.com/allisson/secrets/internal/database" -) - -// dekUseCase implements business logic for Data Encryption Key management. -// Orchestrates DEK rewrapping during KEK rotation. -type dekUseCase struct { - txManager database.TxManager - dekRepo DekRepository - keyManager cryptoService.KeyManager -} - -// Rewrap finds DEKs that are not encrypted with the specified KEK ID, -// decrypts them using their old KEKs, and re-encrypts them with the new KEK. -// Returns the number of DEKs rewrapped in this batch. Executes in a transaction. -func (d *dekUseCase) Rewrap( - ctx context.Context, - kekChain *cryptoDomain.KekChain, - newKekID uuid.UUID, - batchSize int, -) (int, error) { - var rewrappedCount int - - err := d.txManager.WithTx(ctx, func(ctx context.Context) error { - // 1. Fetch batch of DEKs not using the new KEK ID - deks, err := d.dekRepo.GetBatchNotKekID(ctx, newKekID, batchSize) - if err != nil { - return err - } - - if len(deks) == 0 { - return nil - } - - // 2. Get the new KEK from the chain - newKek, ok := kekChain.Get(newKekID) - if !ok { - return cryptoDomain.ErrKekNotFound - } - if newKek.Key == nil { - return cryptoDomain.ErrDecryptionFailed - } - - // 3. Process each DEK in the batch - for _, dek := range deks { - // Get the old KEK - oldKek, ok := kekChain.Get(dek.KekID) - if !ok { - return cryptoDomain.ErrKekNotFound - } - - // Decrypt the DEK plaintext key using the old KEK - dekKey, err := d.keyManager.DecryptDek(dek, oldKek) - if err != nil { - return err - } - - // Encrypt the DEK plaintext key using the new KEK - encryptedKey, nonce, err := d.keyManager.EncryptDek(dekKey, newKek) - cryptoDomain.Zero(dekKey) - if err != nil { - return err - } - - // Update DEK entity - dek.KekID = newKekID - dek.EncryptedKey = encryptedKey - dek.Nonce = nonce - - // Save updated DEK - if err := d.dekRepo.Update(ctx, dek); err != nil { - return err - } - } - - rewrappedCount = len(deks) - return nil - }) - - return rewrappedCount, err -} - -// NewDekUseCase creates a new DekUseCase instance. -func NewDekUseCase( - txManager database.TxManager, - dekRepo DekRepository, - keyManager cryptoService.KeyManager, -) DekUseCase { - return &dekUseCase{ - txManager: txManager, - dekRepo: dekRepo, - keyManager: keyManager, - } -} diff --git a/internal/crypto/usecase/dek_usecase_test.go b/internal/crypto/usecase/dek_usecase_test.go deleted file mode 100644 index 38cd44f..0000000 --- a/internal/crypto/usecase/dek_usecase_test.go +++ /dev/null @@ -1,314 +0,0 @@ -package usecase_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" - "github.com/allisson/secrets/internal/crypto/usecase" - cryptoUsecaseMocks "github.com/allisson/secrets/internal/crypto/usecase/mocks" - dbMocks "github.com/allisson/secrets/internal/database/mocks" -) - -func TestDekUseCase_Rewrap(t *testing.T) { - t.Run("Success", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - oldKekID := uuid.New() - batchSize := 10 - - oldKek := &cryptoDomain.Kek{ - ID: oldKekID, - MasterKeyID: uuid.New().String(), - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("old-encrypted-key"), - Key: []byte("old-key"), - Nonce: []byte("old-nonce"), - Version: 1, - } - - newKek := &cryptoDomain.Kek{ - ID: newKekID, - MasterKeyID: uuid.New().String(), - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("new-encrypted-key"), - Key: []byte("new-key"), - Nonce: []byte("new-nonce"), - Version: 2, - } - - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek, oldKek}) - - dek1 := &cryptoDomain.Dek{ - ID: uuid.New(), - KekID: oldKekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("dek1-encrypted-old"), - Nonce: []byte("dek1-nonce-old"), - CreatedAt: time.Now(), - } - - batch := []*cryptoDomain.Dek{dek1} - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - - plainDek1Key := []byte("dek1-plaintext-key") - keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) - - newEncDek1 := []byte("dek1-encrypted-new") - newNonceDek1 := []byte("dek1-nonce-new") - keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return(newEncDek1, newNonceDek1, nil) - - dekRepo.EXPECT().Update(mock.Anything, mock.MatchedBy(func(dek *cryptoDomain.Dek) bool { - return dek.ID == dek1.ID && - dek.KekID == newKekID && - string(dek.EncryptedKey) == "dek1-encrypted-new" && - string(dek.Nonce) == "dek1-nonce-new" - })).Return(nil) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.NoError(t, err) - assert.Equal(t, 1, rewrapped) - }) - - t.Run("Zero DEKs to rewrap", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - newKek := &cryptoDomain.Kek{ID: newKekID, Key: []byte("new-key")} - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek}) - batchSize := 10 - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - dekRepo.EXPECT(). - GetBatchNotKekID(mock.Anything, newKekID, batchSize). - Return([]*cryptoDomain.Dek{}, nil) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.NoError(t, err) - assert.Equal(t, 0, rewrapped) - }) - - t.Run("New KEK not found in chain", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{{ID: uuid.New()}}) - batchSize := 10 - - dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: uuid.New()} - batch := []*cryptoDomain.Dek{dek1} - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(cryptoDomain.ErrKekNotFound). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.ErrorIs(t, err, cryptoDomain.ErrKekNotFound) - assert.Equal(t, 0, rewrapped) - }) - - t.Run("Old KEK not found in chain", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - oldKekID := uuid.New() - batchSize := 10 - - newKek := &cryptoDomain.Kek{ID: newKekID, Key: []byte("new-key")} - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek}) - - dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} - batch := []*cryptoDomain.Dek{dek1} - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(cryptoDomain.ErrKekNotFound). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.ErrorIs(t, err, cryptoDomain.ErrKekNotFound) - assert.Equal(t, 0, rewrapped) - }) - - t.Run("DecryptDek error", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - oldKekID := uuid.New() - batchSize := 10 - - newKek := &cryptoDomain.Kek{ID: newKekID, Key: []byte("new-key")} - oldKek := &cryptoDomain.Kek{ID: oldKekID, Key: []byte("old-key")} - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek, oldKek}) - - dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} - batch := []*cryptoDomain.Dek{dek1} - - expectedErr := errors.New("decryption failed") - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedErr). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(nil, expectedErr) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.ErrorIs(t, err, expectedErr) - assert.Equal(t, 0, rewrapped) - }) - - t.Run("EncryptDek error", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - oldKekID := uuid.New() - batchSize := 10 - - newKek := &cryptoDomain.Kek{ID: newKekID, Key: []byte("new-key")} - oldKek := &cryptoDomain.Kek{ID: oldKekID, Key: []byte("old-key")} - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek, oldKek}) - - dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} - batch := []*cryptoDomain.Dek{dek1} - - expectedErr := errors.New("encryption failed") - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedErr). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - - plainDek1Key := []byte("plain-key") - keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) - - keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return(nil, nil, expectedErr) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.ErrorIs(t, err, expectedErr) - assert.Equal(t, 0, rewrapped) - }) - - t.Run("Update dek error", func(t *testing.T) { - txManager := dbMocks.NewMockTxManager(t) - dekRepo := cryptoUsecaseMocks.NewMockDekRepository(t) - keyManager := cryptoServiceMocks.NewMockKeyManager(t) - useCase := usecase.NewDekUseCase(txManager, dekRepo, keyManager) - - ctx := context.Background() - newKekID := uuid.New() - oldKekID := uuid.New() - batchSize := 10 - - newKek := &cryptoDomain.Kek{ID: newKekID, Key: []byte("new-key")} - oldKek := &cryptoDomain.Kek{ID: oldKekID, Key: []byte("old-key")} - kekChain := cryptoDomain.NewKekChain([]*cryptoDomain.Kek{newKek, oldKek}) - - dek1 := &cryptoDomain.Dek{ID: uuid.New(), KekID: oldKekID} - batch := []*cryptoDomain.Dek{dek1} - - expectedErr := errors.New("update failed") - - // Setup mock expectations for transaction - txManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedErr). - Once() - - dekRepo.EXPECT().GetBatchNotKekID(mock.Anything, newKekID, batchSize).Return(batch, nil) - - plainDek1Key := []byte("plain-key") - keyManager.EXPECT().DecryptDek(dek1, oldKek).Return(plainDek1Key, nil) - keyManager.EXPECT().EncryptDek(plainDek1Key, newKek).Return([]byte("enc"), []byte("nonce"), nil) - - dekRepo.EXPECT().Update(mock.Anything, dek1).Return(expectedErr) - - rewrapped, err := useCase.Rewrap(ctx, kekChain, newKekID, batchSize) - - assert.ErrorIs(t, err, expectedErr) - assert.Equal(t, 0, rewrapped) - }) -} diff --git a/internal/crypto/usecase/interface.go b/internal/crypto/usecase/interface.go index 40b8a42..e52931e 100644 --- a/internal/crypto/usecase/interface.go +++ b/internal/crypto/usecase/interface.go @@ -1,59 +1,26 @@ -// Package usecase defines business logic interfaces for KEK operations and repository contracts. +// Package usecase defines business logic interfaces for KEK operations. +// +// DEK envelope encryption now lives in internal/keyring. This package only +// retains KEK lifecycle operations driven by the rotation CLI. package usecase import ( "context" - "github.com/google/uuid" - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" ) // KekRepository defines persistence operations for Key Encryption Keys. // Implementations must support transaction-aware operations via context propagation. type KekRepository interface { - // Create stores a new KEK in the repository. Create(ctx context.Context, kek *cryptoDomain.Kek) error - - // Update modifies an existing KEK in the repository. Update(ctx context.Context, kek *cryptoDomain.Kek) error - - // List retrieves all KEKs ordered by version descending (newest first). List(ctx context.Context) ([]*cryptoDomain.Kek, error) } // KekUseCase defines business logic operations for Key Encryption Key management. -// It orchestrates KEK lifecycle including creation, rotation, and unwrapping. type KekUseCase interface { - // Create generates and persists a new KEK using the active master key. Create(ctx context.Context, masterKeyChain *cryptoDomain.MasterKeyChain, alg cryptoDomain.Algorithm) error - - // Rotate performs atomic KEK rotation by creating a new KEK with incremented version. Rotate(ctx context.Context, masterKeyChain *cryptoDomain.MasterKeyChain, alg cryptoDomain.Algorithm) error - - // Unwrap decrypts all KEKs from the database and returns them in a KekChain for in-memory use. Unwrap(ctx context.Context, masterKeyChain *cryptoDomain.MasterKeyChain) (*cryptoDomain.KekChain, error) } - -// DekRepository defines persistence operations for Data Encryption Keys. -// Implementations must support transaction-aware operations via context propagation. -type DekRepository interface { - // Update modifies an existing DEK in the repository. - Update(ctx context.Context, dek *cryptoDomain.Dek) error - - // GetBatchNotKekID retrieves a batch of DEKs that are not encrypted with the given KEK ID. - GetBatchNotKekID(ctx context.Context, kekID uuid.UUID, limit int) ([]*cryptoDomain.Dek, error) -} - -// DekUseCase defines business logic operations for Data Encryption Key management. -type DekUseCase interface { - // Rewrap finds DEKs that are not encrypted with the specified KEK ID, - // decrypts them using their old KEKs, and re-encrypts them with the new KEK. - // Returns the number of DEKs rewrapped in this batch. - Rewrap( - ctx context.Context, - kekChain *cryptoDomain.KekChain, - newKekID uuid.UUID, - batchSize int, - ) (int, error) -} diff --git a/internal/crypto/usecase/mocks/mocks.go b/internal/crypto/usecase/mocks/mocks.go index 4beffe4..fee23ba 100644 --- a/internal/crypto/usecase/mocks/mocks.go +++ b/internal/crypto/usecase/mocks/mocks.go @@ -8,7 +8,6 @@ import ( "context" "github.com/allisson/secrets/internal/crypto/domain" - "github.com/google/uuid" mock "github.com/stretchr/testify/mock" ) @@ -435,266 +434,3 @@ func (_c *MockKekUseCase_Unwrap_Call) RunAndReturn(run func(ctx context.Context, _c.Call.Return(run) return _c } - -// NewMockDekRepository creates a new instance of MockDekRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockDekRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *MockDekRepository { - mock := &MockDekRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockDekRepository is an autogenerated mock type for the DekRepository type -type MockDekRepository struct { - mock.Mock -} - -type MockDekRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockDekRepository) EXPECT() *MockDekRepository_Expecter { - return &MockDekRepository_Expecter{mock: &_m.Mock} -} - -// GetBatchNotKekID provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) GetBatchNotKekID(ctx context.Context, kekID uuid.UUID, limit int) ([]*domain.Dek, error) { - ret := _mock.Called(ctx, kekID, limit) - - if len(ret) == 0 { - panic("no return value specified for GetBatchNotKekID") - } - - var r0 []*domain.Dek - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, int) ([]*domain.Dek, error)); ok { - return returnFunc(ctx, kekID, limit) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, int) []*domain.Dek); ok { - r0 = returnFunc(ctx, kekID, limit) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain.Dek) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID, int) error); ok { - r1 = returnFunc(ctx, kekID, limit) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockDekRepository_GetBatchNotKekID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBatchNotKekID' -type MockDekRepository_GetBatchNotKekID_Call struct { - *mock.Call -} - -// GetBatchNotKekID is a helper method to define mock.On call -// - ctx context.Context -// - kekID uuid.UUID -// - limit int -func (_e *MockDekRepository_Expecter) GetBatchNotKekID(ctx interface{}, kekID interface{}, limit interface{}) *MockDekRepository_GetBatchNotKekID_Call { - return &MockDekRepository_GetBatchNotKekID_Call{Call: _e.mock.On("GetBatchNotKekID", ctx, kekID, limit)} -} - -func (_c *MockDekRepository_GetBatchNotKekID_Call) Run(run func(ctx context.Context, kekID uuid.UUID, limit int)) *MockDekRepository_GetBatchNotKekID_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 uuid.UUID - if args[1] != nil { - arg1 = args[1].(uuid.UUID) - } - var arg2 int - if args[2] != nil { - arg2 = args[2].(int) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockDekRepository_GetBatchNotKekID_Call) Return(deks []*domain.Dek, err error) *MockDekRepository_GetBatchNotKekID_Call { - _c.Call.Return(deks, err) - return _c -} - -func (_c *MockDekRepository_GetBatchNotKekID_Call) RunAndReturn(run func(ctx context.Context, kekID uuid.UUID, limit int) ([]*domain.Dek, error)) *MockDekRepository_GetBatchNotKekID_Call { - _c.Call.Return(run) - return _c -} - -// Update provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Update(ctx context.Context, dek *domain.Dek) error { - ret := _mock.Called(ctx, dek) - - if len(ret) == 0 { - panic("no return value specified for Update") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Dek) error); ok { - r0 = returnFunc(ctx, dek) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockDekRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' -type MockDekRepository_Update_Call struct { - *mock.Call -} - -// Update is a helper method to define mock.On call -// - ctx context.Context -// - dek *domain.Dek -func (_e *MockDekRepository_Expecter) Update(ctx interface{}, dek interface{}) *MockDekRepository_Update_Call { - return &MockDekRepository_Update_Call{Call: _e.mock.On("Update", ctx, dek)} -} - -func (_c *MockDekRepository_Update_Call) Run(run func(ctx context.Context, dek *domain.Dek)) *MockDekRepository_Update_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *domain.Dek - if args[1] != nil { - arg1 = args[1].(*domain.Dek) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Update_Call) Return(err error) *MockDekRepository_Update_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockDekRepository_Update_Call) RunAndReturn(run func(ctx context.Context, dek *domain.Dek) error) *MockDekRepository_Update_Call { - _c.Call.Return(run) - return _c -} - -// NewMockDekUseCase creates a new instance of MockDekUseCase. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockDekUseCase(t interface { - mock.TestingT - Cleanup(func()) -}) *MockDekUseCase { - mock := &MockDekUseCase{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockDekUseCase is an autogenerated mock type for the DekUseCase type -type MockDekUseCase struct { - mock.Mock -} - -type MockDekUseCase_Expecter struct { - mock *mock.Mock -} - -func (_m *MockDekUseCase) EXPECT() *MockDekUseCase_Expecter { - return &MockDekUseCase_Expecter{mock: &_m.Mock} -} - -// Rewrap provides a mock function for the type MockDekUseCase -func (_mock *MockDekUseCase) Rewrap(ctx context.Context, kekChain *domain.KekChain, newKekID uuid.UUID, batchSize int) (int, error) { - ret := _mock.Called(ctx, kekChain, newKekID, batchSize) - - if len(ret) == 0 { - panic("no return value specified for Rewrap") - } - - var r0 int - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.KekChain, uuid.UUID, int) (int, error)); ok { - return returnFunc(ctx, kekChain, newKekID, batchSize) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.KekChain, uuid.UUID, int) int); ok { - r0 = returnFunc(ctx, kekChain, newKekID, batchSize) - } else { - r0 = ret.Get(0).(int) - } - if returnFunc, ok := ret.Get(1).(func(context.Context, *domain.KekChain, uuid.UUID, int) error); ok { - r1 = returnFunc(ctx, kekChain, newKekID, batchSize) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockDekUseCase_Rewrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Rewrap' -type MockDekUseCase_Rewrap_Call struct { - *mock.Call -} - -// Rewrap is a helper method to define mock.On call -// - ctx context.Context -// - kekChain *domain.KekChain -// - newKekID uuid.UUID -// - batchSize int -func (_e *MockDekUseCase_Expecter) Rewrap(ctx interface{}, kekChain interface{}, newKekID interface{}, batchSize interface{}) *MockDekUseCase_Rewrap_Call { - return &MockDekUseCase_Rewrap_Call{Call: _e.mock.On("Rewrap", ctx, kekChain, newKekID, batchSize)} -} - -func (_c *MockDekUseCase_Rewrap_Call) Run(run func(ctx context.Context, kekChain *domain.KekChain, newKekID uuid.UUID, batchSize int)) *MockDekUseCase_Rewrap_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *domain.KekChain - if args[1] != nil { - arg1 = args[1].(*domain.KekChain) - } - var arg2 uuid.UUID - if args[2] != nil { - arg2 = args[2].(uuid.UUID) - } - var arg3 int - if args[3] != nil { - arg3 = args[3].(int) - } - run( - arg0, - arg1, - arg2, - arg3, - ) - }) - return _c -} - -func (_c *MockDekUseCase_Rewrap_Call) Return(n int, err error) *MockDekUseCase_Rewrap_Call { - _c.Call.Return(n, err) - return _c -} - -func (_c *MockDekUseCase_Rewrap_Call) RunAndReturn(run func(ctx context.Context, kekChain *domain.KekChain, newKekID uuid.UUID, batchSize int) (int, error)) *MockDekUseCase_Rewrap_Call { - _c.Call.Return(run) - return _c -} diff --git a/internal/keyring/fake.go b/internal/keyring/fake.go new file mode 100644 index 0000000..35f571b --- /dev/null +++ b/internal/keyring/fake.go @@ -0,0 +1,178 @@ +package keyring + +import ( + "context" + "encoding/binary" + "errors" + "sync" + + "github.com/google/uuid" +) + +// Fake is an in-memory Keyring suitable for feature unit tests. +// +// It does not perform real cryptography: ciphertext is a deterministic +// transformation of the plaintext keyed by the DekID. The point is to give +// features a real second adapter so that Keyring becomes a real seam, and to +// let feature tests assert behaviour (a value can be encrypted, persisted, +// and decrypted back) without touching the database or KMS. +// +// Concurrency-safe. +type Fake struct { + mu sync.Mutex + deks map[uuid.UUID]struct{} + nextDek uint64 + + // FailEncrypt, FailDecrypt, FailAllocate, FailRewrap, when non-nil, make + // the matching method return the stored error. Useful for failure-path + // tests in callers. + FailEncrypt error + FailDecrypt error + FailAllocate error + FailRewrap error +} + +// NewFake constructs an empty Fake. +func NewFake() *Fake { + return &Fake{deks: map[uuid.UUID]struct{}{}} +} + +// Encrypt returns an Envelope whose Ciphertext is a reversible XOR of the +// plaintext with a DekID-derived stream. The DekID is allocated and tracked. +func (f *Fake) Encrypt(_ context.Context, plaintext []byte) (Envelope, error) { + if f.FailEncrypt != nil { + return Envelope{}, f.FailEncrypt + } + + dekID := f.allocate() + return Envelope{ + DekID: dekID, + Ciphertext: xorStream(plaintext, dekID), + Nonce: nonceFor(dekID), + }, nil +} + +// Decrypt reverses Encrypt. Returns an error if the DekID was never allocated. +func (f *Fake) Decrypt(_ context.Context, env Envelope) ([]byte, error) { + if f.FailDecrypt != nil { + return nil, f.FailDecrypt + } + + if !f.knows(env.DekID) { + return nil, errors.New("keyring.Fake: unknown DekID") + } + return xorStream(env.Ciphertext, env.DekID), nil +} + +// AllocateDek returns a fresh handle. Algorithm is ignored. +func (f *Fake) AllocateDek(_ context.Context, _ Algorithm) (DekHandle, error) { + if f.FailAllocate != nil { + return DekHandle{}, f.FailAllocate + } + return DekHandle{DekID: f.allocate()}, nil +} + +// EncryptWith XORs plaintext with a stream derived from the handle's DekID +// and the optional aad. +func (f *Fake) EncryptWith( + _ context.Context, + handle DekHandle, + plaintext, aad []byte, +) (ciphertext, nonce []byte, err error) { + if f.FailEncrypt != nil { + return nil, nil, f.FailEncrypt + } + if !f.knows(handle.DekID) { + return nil, nil, errors.New("keyring.Fake: unknown DekID") + } + return xorStreamAAD(plaintext, handle.DekID, aad), nonceFor(handle.DekID), nil +} + +// DecryptWith reverses EncryptWith. +func (f *Fake) DecryptWith( + _ context.Context, + handle DekHandle, + ciphertext, _ []byte, + aad []byte, +) ([]byte, error) { + if f.FailDecrypt != nil { + return nil, f.FailDecrypt + } + if !f.knows(handle.DekID) { + return nil, errors.New("keyring.Fake: unknown DekID") + } + return xorStreamAAD(ciphertext, handle.DekID, aad), nil +} + +// Rewrap is a no-op for the Fake (there is no KEK chain) but honors +// FailRewrap and validates the DekID is known. +func (f *Fake) Rewrap(_ context.Context, dekID uuid.UUID) error { + if f.FailRewrap != nil { + return f.FailRewrap + } + if !f.knows(dekID) { + return errors.New("keyring.Fake: unknown DekID") + } + return nil +} + +// RewrapAll returns 0 by default for the Fake — there is no notion of a +// stale KEK in-memory. Tests that exercise the rotation worker should +// stub this behaviour by counting tracked DEKs. +func (f *Fake) RewrapAll(_ context.Context, _ int) (int, error) { + if f.FailRewrap != nil { + return 0, f.FailRewrap + } + return 0, nil +} + +// ActiveKekID returns the zero UUID for the Fake; the value is only used +// by the rotation worker to verify the operator-provided KEK ID matches. +func (f *Fake) ActiveKekID() uuid.UUID { + return uuid.Nil +} + +func (f *Fake) allocate() uuid.UUID { + f.mu.Lock() + defer f.mu.Unlock() + f.nextDek++ + var id uuid.UUID + // Spread the counter across both halves so XOR-stream is non-zero + // everywhere (real DekIDs are UUIDv7, with high-entropy across all bytes). + binary.BigEndian.PutUint64(id[0:], f.nextDek^0x9e3779b97f4a7c15) + binary.BigEndian.PutUint64(id[8:], f.nextDek) + f.deks[id] = struct{}{} + return id +} + +func (f *Fake) knows(id uuid.UUID) bool { + f.mu.Lock() + defer f.mu.Unlock() + _, ok := f.deks[id] + return ok +} + +func xorStream(data []byte, dekID uuid.UUID) []byte { + out := make([]byte, len(data)) + for i := range data { + out[i] = data[i] ^ dekID[i%len(dekID)] + } + return out +} + +func xorStreamAAD(data []byte, dekID uuid.UUID, aad []byte) []byte { + out := xorStream(data, dekID) + for i := range out { + if len(aad) == 0 { + break + } + out[i] ^= aad[i%len(aad)] + } + return out +} + +func nonceFor(dekID uuid.UUID) []byte { + n := make([]byte, 12) + copy(n, dekID[:]) + return n +} diff --git a/internal/keyring/fake_test.go b/internal/keyring/fake_test.go new file mode 100644 index 0000000..858b60f --- /dev/null +++ b/internal/keyring/fake_test.go @@ -0,0 +1,102 @@ +package keyring_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/allisson/secrets/internal/keyring" +) + +func TestFake_EncryptDecrypt_RoundTrip(t *testing.T) { + f := keyring.NewFake() + ctx := context.Background() + + plaintext := []byte("hello world") + + env, err := f.Encrypt(ctx, plaintext) + require.NoError(t, err) + assert.NotEqual(t, plaintext, env.Ciphertext) + assert.NotEqual(t, [16]byte{}, [16]byte(env.DekID)) + + got, err := f.Decrypt(ctx, env) + require.NoError(t, err) + assert.Equal(t, plaintext, got) +} + +func TestFake_Decrypt_UnknownDekID(t *testing.T) { + f := keyring.NewFake() + other := keyring.NewFake() + ctx := context.Background() + + env, err := other.Encrypt(ctx, []byte("x")) + require.NoError(t, err) + + _, err = f.Decrypt(ctx, env) + assert.Error(t, err) +} + +func TestFake_AllocateDek_EncryptWith_DecryptWith(t *testing.T) { + f := keyring.NewFake() + ctx := context.Background() + + handle, err := f.AllocateDek(ctx, keyring.AESGCM) + require.NoError(t, err) + + plaintext := []byte("payload") + aad := []byte("ctx") + + ciphertext, nonce, err := f.EncryptWith(ctx, handle, plaintext, aad) + require.NoError(t, err) + assert.NotEqual(t, plaintext, ciphertext) + assert.NotEmpty(t, nonce) + + got, err := f.DecryptWith(ctx, handle, ciphertext, nonce, aad) + require.NoError(t, err) + assert.Equal(t, plaintext, got) +} + +func TestFake_Rewrap_KnownDek(t *testing.T) { + f := keyring.NewFake() + ctx := context.Background() + + env, err := f.Encrypt(ctx, []byte("x")) + require.NoError(t, err) + + require.NoError(t, f.Rewrap(ctx, env.DekID)) + + got, err := f.Decrypt(ctx, env) + require.NoError(t, err) + assert.Equal(t, []byte("x"), got) +} + +func TestFake_FailureInjection(t *testing.T) { + f := keyring.NewFake() + ctx := context.Background() + + boom := errors.New("boom") + + f.FailEncrypt = boom + _, err := f.Encrypt(ctx, []byte("x")) + assert.ErrorIs(t, err, boom) + f.FailEncrypt = nil + + env, err := f.Encrypt(ctx, []byte("x")) + require.NoError(t, err) + + f.FailDecrypt = boom + _, err = f.Decrypt(ctx, env) + assert.ErrorIs(t, err, boom) + f.FailDecrypt = nil + + f.FailAllocate = boom + _, err = f.AllocateDek(ctx, keyring.AESGCM) + assert.ErrorIs(t, err, boom) + f.FailAllocate = nil + + f.FailRewrap = boom + assert.ErrorIs(t, f.Rewrap(ctx, env.DekID), boom) +} diff --git a/internal/keyring/impl.go b/internal/keyring/impl.go new file mode 100644 index 0000000..dbd5b34 --- /dev/null +++ b/internal/keyring/impl.go @@ -0,0 +1,280 @@ +package keyring + +import ( + "context" + "errors" + "time" + + "github.com/google/uuid" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" + cryptoService "github.com/allisson/secrets/internal/crypto/service" +) + +var errKeyringBadBatchSize = errors.New("keyring: batch size must be positive") + +// dekStore is the persistence contract Keyring relies on for DEK rows. +// +// Implemented by *crypto/repository.DekRepository today. Kept narrow here so +// keyring can swap to its own repository once internal/crypto is folded in. +type dekStore interface { + Create(ctx context.Context, dek *cryptoDomain.Dek) error + Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) + Update(ctx context.Context, dek *cryptoDomain.Dek) error + GetBatchNotKekID(ctx context.Context, kekID uuid.UUID, limit int) ([]*cryptoDomain.Dek, error) +} + +// keyring is the production Keyring. It orchestrates KEK chain lookup, DEK +// creation/persistence, AEAD cipher construction, and ciphertext I/O. +type keyring struct { + kekChain *cryptoDomain.KekChain + dekStore dekStore + aeadManager cryptoService.AEADManager + keyManager cryptoService.KeyManager + dekAlgorithm Algorithm +} + +// New constructs a Keyring with the given dependencies. The kekChain must be +// non-empty and have at least one usable KEK (the active one). +func New( + kekChain *cryptoDomain.KekChain, + dekStore dekStore, + aeadManager cryptoService.AEADManager, + keyManager cryptoService.KeyManager, + dekAlgorithm Algorithm, +) Keyring { + return &keyring{ + kekChain: kekChain, + dekStore: dekStore, + aeadManager: aeadManager, + keyManager: keyManager, + dekAlgorithm: dekAlgorithm, + } +} + +func (k *keyring) Encrypt(ctx context.Context, plaintext []byte) (Envelope, error) { + kek, err := k.activeKek() + if err != nil { + return Envelope{}, err + } + + dek, err := k.createAndPersistDek(ctx, kek, k.dekAlgorithm) + if err != nil { + return Envelope{}, err + } + + dekKey, err := k.keyManager.DecryptDek(&dek, kek) + if err != nil { + return Envelope{}, err + } + defer cryptoDomain.Zero(dekKey) + + cipher, err := k.aeadManager.CreateCipher(dekKey, k.dekAlgorithm) + if err != nil { + return Envelope{}, err + } + + ciphertext, nonce, err := cipher.Encrypt(plaintext, nil) + if err != nil { + return Envelope{}, err + } + + return Envelope{ + DekID: dek.ID, + Ciphertext: ciphertext, + Nonce: nonce, + }, nil +} + +func (k *keyring) Decrypt(ctx context.Context, env Envelope) ([]byte, error) { + dek, err := k.dekStore.Get(ctx, env.DekID) + if err != nil { + return nil, err + } + + kek, ok := k.kekChain.Get(dek.KekID) + if !ok { + return nil, cryptoDomain.ErrKekNotFound + } + + dekKey, err := k.keyManager.DecryptDek(dek, kek) + if err != nil { + return nil, err + } + defer cryptoDomain.Zero(dekKey) + + cipher, err := k.aeadManager.CreateCipher(dekKey, dek.Algorithm) + if err != nil { + return nil, err + } + + return cipher.Decrypt(env.Ciphertext, env.Nonce, nil) +} + +func (k *keyring) AllocateDek(ctx context.Context, alg Algorithm) (DekHandle, error) { + kek, err := k.activeKek() + if err != nil { + return DekHandle{}, err + } + + dek, err := k.createAndPersistDek(ctx, kek, alg) + if err != nil { + return DekHandle{}, err + } + + return DekHandle{DekID: dek.ID}, nil +} + +func (k *keyring) EncryptWith( + ctx context.Context, + handle DekHandle, + plaintext, aad []byte, +) (ciphertext, nonce []byte, err error) { + cipher, cleanup, err := k.openCipher(ctx, handle) + if err != nil { + return nil, nil, err + } + defer cleanup() + + return cipher.Encrypt(plaintext, aad) +} + +func (k *keyring) DecryptWith( + ctx context.Context, + handle DekHandle, + ciphertext, nonce, aad []byte, +) ([]byte, error) { + cipher, cleanup, err := k.openCipher(ctx, handle) + if err != nil { + return nil, err + } + defer cleanup() + + return cipher.Decrypt(ciphertext, nonce, aad) +} + +func (k *keyring) Rewrap(ctx context.Context, dekID uuid.UUID) error { + dek, err := k.dekStore.Get(ctx, dekID) + if err != nil { + return err + } + + activeKek, err := k.activeKek() + if err != nil { + return err + } + + if dek.KekID == activeKek.ID { + return nil + } + + oldKek, ok := k.kekChain.Get(dek.KekID) + if !ok { + return cryptoDomain.ErrKekNotFound + } + + dekKey, err := k.keyManager.DecryptDek(dek, oldKek) + if err != nil { + return err + } + defer cryptoDomain.Zero(dekKey) + + newEncKey, newNonce, err := k.keyManager.EncryptDek(dekKey, activeKek) + if err != nil { + return err + } + + dek.KekID = activeKek.ID + dek.EncryptedKey = newEncKey + dek.Nonce = newNonce + + return k.dekStore.Update(ctx, dek) +} + +func (k *keyring) RewrapAll(ctx context.Context, batchSize int) (int, error) { + if batchSize <= 0 { + return 0, errKeyringBadBatchSize + } + + activeKekID := k.kekChain.ActiveKekID() + total := 0 + for { + deks, err := k.dekStore.GetBatchNotKekID(ctx, activeKekID, batchSize) + if err != nil { + return total, err + } + if len(deks) == 0 { + return total, nil + } + for _, dek := range deks { + if err := k.Rewrap(ctx, dek.ID); err != nil { + return total, err + } + total++ + } + } +} + +func (k *keyring) ActiveKekID() uuid.UUID { + return k.kekChain.ActiveKekID() +} + +func (k *keyring) activeKek() (*cryptoDomain.Kek, error) { + kek, ok := k.kekChain.Get(k.kekChain.ActiveKekID()) + if !ok { + return nil, cryptoDomain.ErrKekNotFound + } + return kek, nil +} + +func (k *keyring) createAndPersistDek( + ctx context.Context, + kek *cryptoDomain.Kek, + alg Algorithm, +) (cryptoDomain.Dek, error) { + dek, err := k.keyManager.CreateDek(kek, alg) + if err != nil { + return cryptoDomain.Dek{}, err + } + + if dek.CreatedAt.IsZero() { + dek.CreatedAt = time.Now().UTC() + } + + if err := k.dekStore.Create(ctx, &dek); err != nil { + return cryptoDomain.Dek{}, err + } + + return dek, nil +} + +// openCipher loads the DEK referenced by handle, unwraps it under its KEK, +// and returns an AEAD cipher plus a cleanup function that zeroes the DEK. +func (k *keyring) openCipher( + ctx context.Context, + handle DekHandle, +) (cryptoService.AEAD, func(), error) { + dek, err := k.dekStore.Get(ctx, handle.DekID) + if err != nil { + return nil, func() {}, err + } + + kek, ok := k.kekChain.Get(dek.KekID) + if !ok { + return nil, func() {}, cryptoDomain.ErrKekNotFound + } + + dekKey, err := k.keyManager.DecryptDek(dek, kek) + if err != nil { + return nil, func() {}, err + } + + cipher, err := k.aeadManager.CreateCipher(dekKey, dek.Algorithm) + if err != nil { + cryptoDomain.Zero(dekKey) + return nil, func() {}, err + } + + cleanup := func() { cryptoDomain.Zero(dekKey) } + return cipher, cleanup, nil +} diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go new file mode 100644 index 0000000..9c9507e --- /dev/null +++ b/internal/keyring/keyring.go @@ -0,0 +1,99 @@ +// Package keyring provides envelope encryption as a single deep module. +// +// Callers exchange plaintext for an Envelope (DekID + Ciphertext + Nonce) and +// back. The KEK chain, DEK lifecycle, AEAD cipher selection, and KMS-rooted +// master key chain all live behind this interface. +// +// Two encryption shapes are supported: +// +// - Fresh-DEK envelope (Encrypt/Decrypt) — a new DEK is created for each +// call. Used by the secrets and tokenization features. +// - Persistent DEK (AllocateDek/EncryptWith/DecryptWith) — one DEK is +// allocated once and reused across many encrypt/decrypt calls. Used by +// the transit feature where a named key wraps user payloads repeatedly. +// +// All methods honor an ambient transaction propagated via context; persistence +// joins the caller's tx when one is present (see ADR-0005). +package keyring + +import ( + "context" + + "github.com/google/uuid" + + cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" +) + +// Algorithm is the AEAD algorithm used to wrap a DEK and to encrypt under it. +// Re-exported from internal/crypto/domain so callers do not need that import. +type Algorithm = cryptoDomain.Algorithm + +const ( + // AESGCM is AES-256-GCM (optimal with AES-NI). + AESGCM = cryptoDomain.AESGCM + + // ChaCha20 is ChaCha20-Poly1305 (optimal without AES-NI). + ChaCha20 = cryptoDomain.ChaCha20 +) + +// Envelope is the result of Keyring.Encrypt and the input to Keyring.Decrypt. +// Callers persist the three fields exactly as they would today; nothing about +// the KEK or DEK material is exposed. +type Envelope struct { + DekID uuid.UUID + Ciphertext []byte + Nonce []byte +} + +// DekHandle is an opaque reference to a persistent DEK allocated via +// AllocateDek. Callers store only the DekID and reload the handle on demand. +type DekHandle struct { + DekID uuid.UUID +} + +// Keyring is the single seam features use to encrypt and decrypt data. +// +// Implementations must be safe for concurrent use. +type Keyring interface { + // Encrypt creates a fresh DEK, persists it, and uses it to encrypt + // plaintext exactly once. The returned Envelope contains everything a + // future Decrypt call needs. + Encrypt(ctx context.Context, plaintext []byte) (Envelope, error) + + // Decrypt reverses Encrypt for an Envelope produced by this Keyring. + Decrypt(ctx context.Context, env Envelope) ([]byte, error) + + // AllocateDek creates and persists a new DEK and returns a handle to it. + // The DEK is wrapped under the active KEK with the given algorithm. + AllocateDek(ctx context.Context, alg Algorithm) (DekHandle, error) + + // EncryptWith encrypts plaintext under the DEK identified by handle and + // authenticates the optional aad. Reuses the DEK across many calls. + EncryptWith( + ctx context.Context, + handle DekHandle, + plaintext, aad []byte, + ) (ciphertext, nonce []byte, err error) + + // DecryptWith reverses EncryptWith for ciphertext produced under the + // same DekHandle with the same aad. + DecryptWith( + ctx context.Context, + handle DekHandle, + ciphertext, nonce, aad []byte, + ) ([]byte, error) + + // Rewrap re-encrypts an existing DEK under the active KEK without + // changing the underlying key material. Used by KEK rotation. + Rewrap(ctx context.Context, dekID uuid.UUID) error + + // RewrapAll re-encrypts, in batches, every DEK not already under the + // active KEK. Returns the total number of DEKs rewrapped across all + // batches. Used by the KEK rotation worker. + RewrapAll(ctx context.Context, batchSize int) (int, error) + + // ActiveKekID returns the ID of the KEK currently used for new + // encryption. Exposed so the rotation worker can confirm it is + // targeting the same KEK the application is using. + ActiveKekID() uuid.UUID +} diff --git a/internal/secrets/usecase/interface.go b/internal/secrets/usecase/interface.go index 387827d..878d994 100644 --- a/internal/secrets/usecase/interface.go +++ b/internal/secrets/usecase/interface.go @@ -1,27 +1,15 @@ // Package usecase defines the interfaces and implementations for secret management use cases. -// Use cases orchestrate operations between repositories and services to implement business -// logic for managing encrypted secrets with automatic versioning. +// Use cases orchestrate operations between the keyring and the secret repository to +// implement business logic for managing encrypted secrets with automatic versioning. package usecase import ( "context" "time" - "github.com/google/uuid" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" secretsDomain "github.com/allisson/secrets/internal/secrets/domain" ) -// DekRepository defines the interface for Data Encryption Key persistence operations. -type DekRepository interface { - // Create stores a new DEK in the repository using transaction support from context. - Create(ctx context.Context, dek *cryptoDomain.Dek) error - - // Get retrieves a DEK by its ID. Returns ErrDekNotFound if not found. - Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) -} - // SecretRepository defines the interface for Secret persistence operations. type SecretRepository interface { // Create stores a new secret in the repository using transaction support from context. @@ -52,7 +40,7 @@ type SecretRepository interface { // SecretUseCase defines the interface for secret management business logic. type SecretUseCase interface { // CreateOrUpdate creates a new secret or increments the version if path exists. - // Encrypts the value with a new DEK for each version. Returns the created/updated secret. + // Encrypts the value with a fresh DEK via the keyring on each call. CreateOrUpdate(ctx context.Context, path string, value []byte) (*secretsDomain.Secret, error) // Get retrieves and decrypts a secret by its path (latest version). diff --git a/internal/secrets/usecase/mocks/mocks.go b/internal/secrets/usecase/mocks/mocks.go index ad0282f..87578ba 100644 --- a/internal/secrets/usecase/mocks/mocks.go +++ b/internal/secrets/usecase/mocks/mocks.go @@ -8,164 +8,10 @@ import ( "context" "time" - "github.com/allisson/secrets/internal/crypto/domain" - domain0 "github.com/allisson/secrets/internal/secrets/domain" - "github.com/google/uuid" + "github.com/allisson/secrets/internal/secrets/domain" mock "github.com/stretchr/testify/mock" ) -// NewMockDekRepository creates a new instance of MockDekRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockDekRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *MockDekRepository { - mock := &MockDekRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockDekRepository is an autogenerated mock type for the DekRepository type -type MockDekRepository struct { - mock.Mock -} - -type MockDekRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockDekRepository) EXPECT() *MockDekRepository_Expecter { - return &MockDekRepository_Expecter{mock: &_m.Mock} -} - -// Create provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Create(ctx context.Context, dek *domain.Dek) error { - ret := _mock.Called(ctx, dek) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Dek) error); ok { - r0 = returnFunc(ctx, dek) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockDekRepository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type MockDekRepository_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - dek *domain.Dek -func (_e *MockDekRepository_Expecter) Create(ctx interface{}, dek interface{}) *MockDekRepository_Create_Call { - return &MockDekRepository_Create_Call{Call: _e.mock.On("Create", ctx, dek)} -} - -func (_c *MockDekRepository_Create_Call) Run(run func(ctx context.Context, dek *domain.Dek)) *MockDekRepository_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *domain.Dek - if args[1] != nil { - arg1 = args[1].(*domain.Dek) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Create_Call) Return(err error) *MockDekRepository_Create_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockDekRepository_Create_Call) RunAndReturn(run func(ctx context.Context, dek *domain.Dek) error) *MockDekRepository_Create_Call { - _c.Call.Return(run) - return _c -} - -// Get provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Get(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error) { - ret := _mock.Called(ctx, dekID) - - if len(ret) == 0 { - panic("no return value specified for Get") - } - - var r0 *domain.Dek - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*domain.Dek, error)); ok { - return returnFunc(ctx, dekID) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) *domain.Dek); ok { - r0 = returnFunc(ctx, dekID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain.Dek) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { - r1 = returnFunc(ctx, dekID) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockDekRepository_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' -type MockDekRepository_Get_Call struct { - *mock.Call -} - -// Get is a helper method to define mock.On call -// - ctx context.Context -// - dekID uuid.UUID -func (_e *MockDekRepository_Expecter) Get(ctx interface{}, dekID interface{}) *MockDekRepository_Get_Call { - return &MockDekRepository_Get_Call{Call: _e.mock.On("Get", ctx, dekID)} -} - -func (_c *MockDekRepository_Get_Call) Run(run func(ctx context.Context, dekID uuid.UUID)) *MockDekRepository_Get_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 uuid.UUID - if args[1] != nil { - arg1 = args[1].(uuid.UUID) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Get_Call) Return(dek *domain.Dek, err error) *MockDekRepository_Get_Call { - _c.Call.Return(dek, err) - return _c -} - -func (_c *MockDekRepository_Get_Call) RunAndReturn(run func(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error)) *MockDekRepository_Get_Call { - _c.Call.Return(run) - return _c -} - // NewMockSecretRepository creates a new instance of MockSecretRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockSecretRepository(t interface { @@ -194,7 +40,7 @@ func (_m *MockSecretRepository) EXPECT() *MockSecretRepository_Expecter { } // Create provides a mock function for the type MockSecretRepository -func (_mock *MockSecretRepository) Create(ctx context.Context, secret *domain0.Secret) error { +func (_mock *MockSecretRepository) Create(ctx context.Context, secret *domain.Secret) error { ret := _mock.Called(ctx, secret) if len(ret) == 0 { @@ -202,7 +48,7 @@ func (_mock *MockSecretRepository) Create(ctx context.Context, secret *domain0.S } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain0.Secret) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Secret) error); ok { r0 = returnFunc(ctx, secret) } else { r0 = ret.Error(0) @@ -217,20 +63,20 @@ type MockSecretRepository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - secret *domain0.Secret +// - secret *domain.Secret func (_e *MockSecretRepository_Expecter) Create(ctx interface{}, secret interface{}) *MockSecretRepository_Create_Call { return &MockSecretRepository_Create_Call{Call: _e.mock.On("Create", ctx, secret)} } -func (_c *MockSecretRepository_Create_Call) Run(run func(ctx context.Context, secret *domain0.Secret)) *MockSecretRepository_Create_Call { +func (_c *MockSecretRepository_Create_Call) Run(run func(ctx context.Context, secret *domain.Secret)) *MockSecretRepository_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 *domain0.Secret + var arg1 *domain.Secret if args[1] != nil { - arg1 = args[1].(*domain0.Secret) + arg1 = args[1].(*domain.Secret) } run( arg0, @@ -245,7 +91,7 @@ func (_c *MockSecretRepository_Create_Call) Return(err error) *MockSecretReposit return _c } -func (_c *MockSecretRepository_Create_Call) RunAndReturn(run func(ctx context.Context, secret *domain0.Secret) error) *MockSecretRepository_Create_Call { +func (_c *MockSecretRepository_Create_Call) RunAndReturn(run func(ctx context.Context, secret *domain.Secret) error) *MockSecretRepository_Create_Call { _c.Call.Return(run) return _c } @@ -308,23 +154,23 @@ func (_c *MockSecretRepository_Delete_Call) RunAndReturn(run func(ctx context.Co } // GetByPath provides a mock function for the type MockSecretRepository -func (_mock *MockSecretRepository) GetByPath(ctx context.Context, path string) (*domain0.Secret, error) { +func (_mock *MockSecretRepository) GetByPath(ctx context.Context, path string) (*domain.Secret, error) { ret := _mock.Called(ctx, path) if len(ret) == 0 { panic("no return value specified for GetByPath") } - var r0 *domain0.Secret + var r0 *domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.Secret, error)); ok { return returnFunc(ctx, path) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.Secret); ok { r0 = returnFunc(ctx, path) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Secret) + r0 = ret.Get(0).(*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -365,34 +211,34 @@ func (_c *MockSecretRepository_GetByPath_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockSecretRepository_GetByPath_Call) Return(secret *domain0.Secret, err error) *MockSecretRepository_GetByPath_Call { +func (_c *MockSecretRepository_GetByPath_Call) Return(secret *domain.Secret, err error) *MockSecretRepository_GetByPath_Call { _c.Call.Return(secret, err) return _c } -func (_c *MockSecretRepository_GetByPath_Call) RunAndReturn(run func(ctx context.Context, path string) (*domain0.Secret, error)) *MockSecretRepository_GetByPath_Call { +func (_c *MockSecretRepository_GetByPath_Call) RunAndReturn(run func(ctx context.Context, path string) (*domain.Secret, error)) *MockSecretRepository_GetByPath_Call { _c.Call.Return(run) return _c } // GetByPathAndVersion provides a mock function for the type MockSecretRepository -func (_mock *MockSecretRepository) GetByPathAndVersion(ctx context.Context, path string, version uint) (*domain0.Secret, error) { +func (_mock *MockSecretRepository) GetByPathAndVersion(ctx context.Context, path string, version uint) (*domain.Secret, error) { ret := _mock.Called(ctx, path, version) if len(ret) == 0 { panic("no return value specified for GetByPathAndVersion") } - var r0 *domain0.Secret + var r0 *domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.Secret, error)); ok { return returnFunc(ctx, path, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.Secret); ok { r0 = returnFunc(ctx, path, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Secret) + r0 = ret.Get(0).(*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) error); ok { @@ -439,12 +285,12 @@ func (_c *MockSecretRepository_GetByPathAndVersion_Call) Run(run func(ctx contex return _c } -func (_c *MockSecretRepository_GetByPathAndVersion_Call) Return(secret *domain0.Secret, err error) *MockSecretRepository_GetByPathAndVersion_Call { +func (_c *MockSecretRepository_GetByPathAndVersion_Call) Return(secret *domain.Secret, err error) *MockSecretRepository_GetByPathAndVersion_Call { _c.Call.Return(secret, err) return _c } -func (_c *MockSecretRepository_GetByPathAndVersion_Call) RunAndReturn(run func(ctx context.Context, path string, version uint) (*domain0.Secret, error)) *MockSecretRepository_GetByPathAndVersion_Call { +func (_c *MockSecretRepository_GetByPathAndVersion_Call) RunAndReturn(run func(ctx context.Context, path string, version uint) (*domain.Secret, error)) *MockSecretRepository_GetByPathAndVersion_Call { _c.Call.Return(run) return _c } @@ -522,23 +368,23 @@ func (_c *MockSecretRepository_HardDelete_Call) RunAndReturn(run func(ctx contex } // ListCursor provides a mock function for the type MockSecretRepository -func (_mock *MockSecretRepository) ListCursor(ctx context.Context, afterPath *string, limit int) ([]*domain0.Secret, error) { +func (_mock *MockSecretRepository) ListCursor(ctx context.Context, afterPath *string, limit int) ([]*domain.Secret, error) { ret := _mock.Called(ctx, afterPath, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.Secret + var r0 []*domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.Secret, error)); ok { return returnFunc(ctx, afterPath, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.Secret); ok { r0 = returnFunc(ctx, afterPath, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.Secret) + r0 = ret.Get(0).([]*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -585,12 +431,12 @@ func (_c *MockSecretRepository_ListCursor_Call) Run(run func(ctx context.Context return _c } -func (_c *MockSecretRepository_ListCursor_Call) Return(secrets []*domain0.Secret, err error) *MockSecretRepository_ListCursor_Call { +func (_c *MockSecretRepository_ListCursor_Call) Return(secrets []*domain.Secret, err error) *MockSecretRepository_ListCursor_Call { _c.Call.Return(secrets, err) return _c } -func (_c *MockSecretRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterPath *string, limit int) ([]*domain0.Secret, error)) *MockSecretRepository_ListCursor_Call { +func (_c *MockSecretRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterPath *string, limit int) ([]*domain.Secret, error)) *MockSecretRepository_ListCursor_Call { _c.Call.Return(run) return _c } @@ -623,23 +469,23 @@ func (_m *MockSecretUseCase) EXPECT() *MockSecretUseCase_Expecter { } // CreateOrUpdate provides a mock function for the type MockSecretUseCase -func (_mock *MockSecretUseCase) CreateOrUpdate(ctx context.Context, path string, value []byte) (*domain0.Secret, error) { +func (_mock *MockSecretUseCase) CreateOrUpdate(ctx context.Context, path string, value []byte) (*domain.Secret, error) { ret := _mock.Called(ctx, path, value) if len(ret) == 0 { panic("no return value specified for CreateOrUpdate") } - var r0 *domain0.Secret + var r0 *domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) (*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) (*domain.Secret, error)); ok { return returnFunc(ctx, path, value) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) *domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) *domain.Secret); ok { r0 = returnFunc(ctx, path, value) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Secret) + r0 = ret.Get(0).(*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, []byte) error); ok { @@ -686,12 +532,12 @@ func (_c *MockSecretUseCase_CreateOrUpdate_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockSecretUseCase_CreateOrUpdate_Call) Return(secret *domain0.Secret, err error) *MockSecretUseCase_CreateOrUpdate_Call { +func (_c *MockSecretUseCase_CreateOrUpdate_Call) Return(secret *domain.Secret, err error) *MockSecretUseCase_CreateOrUpdate_Call { _c.Call.Return(secret, err) return _c } -func (_c *MockSecretUseCase_CreateOrUpdate_Call) RunAndReturn(run func(ctx context.Context, path string, value []byte) (*domain0.Secret, error)) *MockSecretUseCase_CreateOrUpdate_Call { +func (_c *MockSecretUseCase_CreateOrUpdate_Call) RunAndReturn(run func(ctx context.Context, path string, value []byte) (*domain.Secret, error)) *MockSecretUseCase_CreateOrUpdate_Call { _c.Call.Return(run) return _c } @@ -754,23 +600,23 @@ func (_c *MockSecretUseCase_Delete_Call) RunAndReturn(run func(ctx context.Conte } // Get provides a mock function for the type MockSecretUseCase -func (_mock *MockSecretUseCase) Get(ctx context.Context, path string) (*domain0.Secret, error) { +func (_mock *MockSecretUseCase) Get(ctx context.Context, path string) (*domain.Secret, error) { ret := _mock.Called(ctx, path) if len(ret) == 0 { panic("no return value specified for Get") } - var r0 *domain0.Secret + var r0 *domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.Secret, error)); ok { return returnFunc(ctx, path) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.Secret); ok { r0 = returnFunc(ctx, path) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Secret) + r0 = ret.Get(0).(*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -811,34 +657,34 @@ func (_c *MockSecretUseCase_Get_Call) Run(run func(ctx context.Context, path str return _c } -func (_c *MockSecretUseCase_Get_Call) Return(secret *domain0.Secret, err error) *MockSecretUseCase_Get_Call { +func (_c *MockSecretUseCase_Get_Call) Return(secret *domain.Secret, err error) *MockSecretUseCase_Get_Call { _c.Call.Return(secret, err) return _c } -func (_c *MockSecretUseCase_Get_Call) RunAndReturn(run func(ctx context.Context, path string) (*domain0.Secret, error)) *MockSecretUseCase_Get_Call { +func (_c *MockSecretUseCase_Get_Call) RunAndReturn(run func(ctx context.Context, path string) (*domain.Secret, error)) *MockSecretUseCase_Get_Call { _c.Call.Return(run) return _c } // GetByVersion provides a mock function for the type MockSecretUseCase -func (_mock *MockSecretUseCase) GetByVersion(ctx context.Context, path string, version uint) (*domain0.Secret, error) { +func (_mock *MockSecretUseCase) GetByVersion(ctx context.Context, path string, version uint) (*domain.Secret, error) { ret := _mock.Called(ctx, path, version) if len(ret) == 0 { panic("no return value specified for GetByVersion") } - var r0 *domain0.Secret + var r0 *domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.Secret, error)); ok { return returnFunc(ctx, path, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.Secret); ok { r0 = returnFunc(ctx, path, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Secret) + r0 = ret.Get(0).(*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) error); ok { @@ -885,34 +731,34 @@ func (_c *MockSecretUseCase_GetByVersion_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockSecretUseCase_GetByVersion_Call) Return(secret *domain0.Secret, err error) *MockSecretUseCase_GetByVersion_Call { +func (_c *MockSecretUseCase_GetByVersion_Call) Return(secret *domain.Secret, err error) *MockSecretUseCase_GetByVersion_Call { _c.Call.Return(secret, err) return _c } -func (_c *MockSecretUseCase_GetByVersion_Call) RunAndReturn(run func(ctx context.Context, path string, version uint) (*domain0.Secret, error)) *MockSecretUseCase_GetByVersion_Call { +func (_c *MockSecretUseCase_GetByVersion_Call) RunAndReturn(run func(ctx context.Context, path string, version uint) (*domain.Secret, error)) *MockSecretUseCase_GetByVersion_Call { _c.Call.Return(run) return _c } // ListCursor provides a mock function for the type MockSecretUseCase -func (_mock *MockSecretUseCase) ListCursor(ctx context.Context, afterPath *string, limit int) ([]*domain0.Secret, error) { +func (_mock *MockSecretUseCase) ListCursor(ctx context.Context, afterPath *string, limit int) ([]*domain.Secret, error) { ret := _mock.Called(ctx, afterPath, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.Secret + var r0 []*domain.Secret var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.Secret, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.Secret, error)); ok { return returnFunc(ctx, afterPath, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.Secret); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.Secret); ok { r0 = returnFunc(ctx, afterPath, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.Secret) + r0 = ret.Get(0).([]*domain.Secret) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -959,12 +805,12 @@ func (_c *MockSecretUseCase_ListCursor_Call) Run(run func(ctx context.Context, a return _c } -func (_c *MockSecretUseCase_ListCursor_Call) Return(secrets []*domain0.Secret, err error) *MockSecretUseCase_ListCursor_Call { +func (_c *MockSecretUseCase_ListCursor_Call) Return(secrets []*domain.Secret, err error) *MockSecretUseCase_ListCursor_Call { _c.Call.Return(secrets, err) return _c } -func (_c *MockSecretUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterPath *string, limit int) ([]*domain0.Secret, error)) *MockSecretUseCase_ListCursor_Call { +func (_c *MockSecretUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterPath *string, limit int) ([]*domain.Secret, error)) *MockSecretUseCase_ListCursor_Call { _c.Call.Return(run) return _c } diff --git a/internal/secrets/usecase/secret_usecase.go b/internal/secrets/usecase/secret_usecase.go index be6201e..3015680 100644 --- a/internal/secrets/usecase/secret_usecase.go +++ b/internal/secrets/usecase/secret_usecase.go @@ -1,6 +1,4 @@ // Package usecase implements business logic orchestration for secret management. -// This package coordinates between cryptographic services, repositories, and domain logic -// to implement secure secret storage and retrieval with automatic versioning. package usecase import ( @@ -11,20 +9,16 @@ import ( "github.com/google/uuid" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoService "github.com/allisson/secrets/internal/crypto/service" "github.com/allisson/secrets/internal/database" + "github.com/allisson/secrets/internal/keyring" secretsDomain "github.com/allisson/secrets/internal/secrets/domain" ) // secretUseCase implements the SecretUseCase interface for managing secrets. type secretUseCase struct { txManager database.TxManager - dekRepo DekRepository + keyring keyring.Keyring secretRepo SecretRepository - kekChain *cryptoDomain.KekChain - aeadManager cryptoService.AEADManager - keyManager cryptoService.KeyManager - dekAlgorithm cryptoDomain.Algorithm secretValueSizeLimit int } @@ -34,38 +28,18 @@ func (s *secretUseCase) CreateOrUpdate( path string, value []byte, ) (*secretsDomain.Secret, error) { - // Validate secret path if err := validateSecretPath(path); err != nil { return nil, err } - // Check if the secret value size exceeds the limit if len(value) > s.secretValueSizeLimit { return nil, secretsDomain.ErrSecretValueTooLarge } - activeKek, found := s.kekChain.Get(s.kekChain.ActiveKekID()) - if !found { - return nil, cryptoDomain.ErrKekNotFound - } - - return s.createOrUpdateSecret(ctx, path, value, activeKek) -} - -// createOrUpdateSecret is a helper method that handles the secret creation/update logic. -func (s *secretUseCase) createOrUpdateSecret( - ctx context.Context, - path string, - value []byte, - kek *cryptoDomain.Kek, -) (*secretsDomain.Secret, error) { - // Execute the creation within a transaction var newSecret *secretsDomain.Secret err := s.txManager.WithTx(ctx, func(txCtx context.Context) error { + // Version lookup happens inside the transaction to avoid races. var version uint = 1 - - // Check if secret already exists to determine the version - // This must happen inside the transaction to prevent race conditions existingSecret, err := s.secretRepo.GetByPath(txCtx, path) if err != nil && !errors.Is(err, secretsDomain.ErrSecretNotFound) { return err @@ -73,55 +47,23 @@ func (s *secretUseCase) createOrUpdateSecret( if existingSecret != nil { version = existingSecret.Version + 1 } - // Create a new DEK for this secret - dek, err := s.keyManager.CreateDek(kek, s.dekAlgorithm) - if err != nil { - return err - } - // Persist the DEK - if err := s.dekRepo.Create(txCtx, &dek); err != nil { - return err - } - - // Decrypt the DEK first to get the plaintext key - dekKey, err := s.keyManager.DecryptDek(&dek, kek) - if err != nil { - return err - } - defer cryptoDomain.Zero(dekKey) - - // Create cipher with the decrypted DEK key - cipher, err := s.aeadManager.CreateCipher(dekKey, s.dekAlgorithm) - if err != nil { - return err - } - - // Encrypt the secret value - ciphertext, nonce, err := cipher.Encrypt(value, nil) + env, err := s.keyring.Encrypt(txCtx, value) if err != nil { return err } - // Create the secret entity newSecret = &secretsDomain.Secret{ ID: uuid.Must(uuid.NewV7()), Path: path, Version: version, - DekID: dek.ID, - Ciphertext: ciphertext, - Nonce: nonce, + DekID: env.DekID, + Ciphertext: env.Ciphertext, + Nonce: env.Nonce, CreatedAt: time.Now().UTC(), } - - // Persist the secret - if err := s.secretRepo.Create(txCtx, newSecret); err != nil { - return err - } - - return nil + return s.secretRepo.Create(txCtx, newSecret) }) - if err != nil { return nil, err } @@ -131,12 +73,10 @@ func (s *secretUseCase) createOrUpdateSecret( // Get retrieves and decrypts a secret by its path (latest version). func (s *secretUseCase) Get(ctx context.Context, path string) (*secretsDomain.Secret, error) { - // Retrieve the secret by path secret, err := s.secretRepo.GetByPath(ctx, path) if err != nil { return nil, err } - return s.decryptSecret(ctx, secret) } @@ -146,60 +86,38 @@ func (s *secretUseCase) GetByVersion( path string, version uint, ) (*secretsDomain.Secret, error) { - // Retrieve the secret by path and version secret, err := s.secretRepo.GetByPathAndVersion(ctx, path, version) if err != nil { return nil, err } - return s.decryptSecret(ctx, secret) } -// decryptSecret is a helper method that decrypts a secret's ciphertext. func (s *secretUseCase) decryptSecret( ctx context.Context, secret *secretsDomain.Secret, ) (*secretsDomain.Secret, error) { - // Retrieve the DEK - dek, err := s.dekRepo.Get(ctx, secret.DekID) - if err != nil { - return nil, err - } - - // Retrieve the KEK needed to decrypt the DEK - kek, found := s.kekChain.Get(dek.KekID) - if !found { - return nil, cryptoDomain.ErrKekNotFound - } - - // Decrypt the DEK - dekKey, err := s.keyManager.DecryptDek(dek, kek) - if err != nil { - return nil, err - } - defer cryptoDomain.Zero(dekKey) - - // Create cipher with the decrypted DEK key - cipher, err := s.aeadManager.CreateCipher(dekKey, dek.Algorithm) - if err != nil { - return nil, err - } - - // Decrypt the secret value - plaintext, err := cipher.Decrypt(secret.Ciphertext, secret.Nonce, nil) + plaintext, err := s.keyring.Decrypt(ctx, keyring.Envelope{ + DekID: secret.DekID, + Ciphertext: secret.Ciphertext, + Nonce: secret.Nonce, + }) if err != nil { + // Preserve the existing error contract: decryption failures surface as + // ErrDecryptionFailed; lookup failures (DEK/KEK missing) pass through. + if errors.Is(err, cryptoDomain.ErrDekNotFound) || + errors.Is(err, cryptoDomain.ErrKekNotFound) { + return nil, err + } return nil, cryptoDomain.ErrDecryptionFailed } - // Populate the plaintext field secret.Plaintext = plaintext - return secret, nil } // Delete performs a soft delete on all versions of a secret by its path. func (s *secretUseCase) Delete(ctx context.Context, path string) error { - // Perform soft delete on all versions return s.secretRepo.Delete(ctx, path) } @@ -213,8 +131,6 @@ func (s *secretUseCase) ListCursor( } // PurgeDeleted permanently removes soft-deleted secrets older than specified days. -// If dryRun is true, returns count without performing deletion. -// Returns the number of secrets that were (or would be) deleted. func (s *secretUseCase) PurgeDeleted(ctx context.Context, olderThanDays int, dryRun bool) (int64, error) { if olderThanDays < 0 { return 0, errors.New("olderThanDays must be non-negative") @@ -224,25 +140,17 @@ func (s *secretUseCase) PurgeDeleted(ctx context.Context, olderThanDays int, dry return s.secretRepo.HardDelete(ctx, olderThan, dryRun) } -// NewSecretUseCase creates a new secret use case instance with the provided dependencies. +// NewSecretUseCase creates a new secret use case backed by a Keyring. func NewSecretUseCase( txManager database.TxManager, - dekRepo DekRepository, + kr keyring.Keyring, secretRepo SecretRepository, - kekChain *cryptoDomain.KekChain, - aeadManager cryptoService.AEADManager, - keyManager cryptoService.KeyManager, - dekAlgorithm cryptoDomain.Algorithm, secretValueSizeLimit int, ) SecretUseCase { return &secretUseCase{ txManager: txManager, - dekRepo: dekRepo, + keyring: kr, secretRepo: secretRepo, - kekChain: kekChain, - aeadManager: aeadManager, - keyManager: keyManager, - dekAlgorithm: dekAlgorithm, secretValueSizeLimit: secretValueSizeLimit, } } diff --git a/internal/secrets/usecase/secret_usecase_test.go b/internal/secrets/usecase/secret_usecase_test.go index c694575..8281337 100644 --- a/internal/secrets/usecase/secret_usecase_test.go +++ b/internal/secrets/usecase/secret_usecase_test.go @@ -1,4 +1,4 @@ -package usecase +package usecase_test import ( "context" @@ -9,1624 +9,370 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" - databaseMocks "github.com/allisson/secrets/internal/database/mocks" apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" secretsDomain "github.com/allisson/secrets/internal/secrets/domain" - secretsUsecaseMocks "github.com/allisson/secrets/internal/secrets/usecase/mocks" + "github.com/allisson/secrets/internal/secrets/usecase" + "github.com/allisson/secrets/internal/secrets/usecase/mocks" ) -// TestSecretUseCase_CreateOrUpdate tests the CreateOrUpdate method of secretUseCase. +// noopTxManager runs the function with no real transaction. The secrets +// use case writes nothing the in-memory keyring Fake cares about, so the +// outer transaction is not load-bearing for these unit tests. +type noopTxManager struct{} + +func (noopTxManager) WithTx(ctx context.Context, fn func(ctx context.Context) error) error { + return fn(ctx) +} + +// newSecretUseCase builds a SecretUseCase wired to a Fake keyring and a +// mocked SecretRepository. +func newSecretUseCase( + t *testing.T, + sizeLimit int, +) (usecase.SecretUseCase, *keyring.Fake, *mocks.MockSecretRepository) { + t.Helper() + fake := keyring.NewFake() + repo := mocks.NewMockSecretRepository(t) + uc := usecase.NewSecretUseCase(noopTxManager{}, fake, repo, sizeLimit) + return uc, fake, repo +} + +// ============================================================================= +// CreateOrUpdate +// ============================================================================= + func TestSecretUseCase_CreateOrUpdate(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("Success_CreateNewSecret", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Create test data - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - value := []byte("secret-value") - - dekID := uuid.Must(uuid.NewV7()) - dek := cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } + uc, fake, repo := newSecretUseCase(t, 1024) + path := "app/db-password" + value := []byte("super-secret") - ciphertext := []byte("encrypted-secret") - nonce := []byte("secret-nonce") - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(mock.Anything, path). - Return(nil, secretsDomain.ErrSecretNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, &dek). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(&dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(value, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - mockSecretRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(secret *secretsDomain.Secret) bool { - return secret.Path == path && - secret.Version == 1 && - secret.DekID == dekID + repo.EXPECT(). + GetByPath(ctx, path). + Return(nil, secretsDomain.ErrSecretNotFound) + repo.EXPECT(). + Create(ctx, mock.MatchedBy(func(s *secretsDomain.Secret) bool { + return s.Path == path && s.Version == 1 && len(s.Ciphertext) > 0 })). - Return(nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, secret) - assert.Equal(t, path, secret.Path) - assert.Equal(t, uint(1), secret.Version) - assert.Equal(t, dekID, secret.DekID) + Return(nil) + + got, err := uc.CreateOrUpdate(ctx, path, value) + require.NoError(t, err) + assert.Equal(t, path, got.Path) + assert.EqualValues(t, 1, got.Version) + assert.NotEqual(t, uuid.Nil, got.DekID) + + // Round-trip: decrypting via the same fake should give back value. + plaintext, err := fake.Decrypt(ctx, keyring.Envelope{ + DekID: got.DekID, + Ciphertext: got.Ciphertext, + Nonce: got.Nonce, + }) + require.NoError(t, err) + assert.Equal(t, value, plaintext) }) - t.Run("Success_UpdateExistingSecret", func(t *testing.T) { + t.Run("Success_UpdateExistingSecret_VersionIncrements", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Create test data - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - + uc, _, repo := newSecretUseCase(t, 1024) path := "app/api-key" - value := []byte("new-secret-value") - - existingSecret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 2, - DekID: uuid.Must(uuid.NewV7()), - Ciphertext: []byte("old-encrypted-secret"), - Nonce: []byte("old-nonce"), - CreatedAt: time.Now().UTC(), - } + existing := &secretsDomain.Secret{Path: path, Version: 3} - dekID := uuid.Must(uuid.NewV7()) - dek := cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } + repo.EXPECT().GetByPath(ctx, path).Return(existing, nil) + repo.EXPECT().Create(ctx, mock.MatchedBy(func(s *secretsDomain.Secret) bool { + return s.Version == 4 + })).Return(nil) - ciphertext := []byte("new-encrypted-secret") - nonce := []byte("new-secret-nonce") - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(mock.Anything, path). - Return(existingSecret, nil). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, &dek). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(&dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(value, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - mockSecretRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(secret *secretsDomain.Secret) bool { - return secret.Path == path && - secret.Version == 3 && // Incremented version - secret.DekID == dekID - })). - Return(nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, secret) - assert.Equal(t, path, secret.Path) - assert.Equal(t, uint(3), secret.Version) - assert.Equal(t, dekID, secret.DekID) + got, err := uc.CreateOrUpdate(ctx, path, []byte("v4")) + require.NoError(t, err) + assert.EqualValues(t, 4, got.Version) }) - t.Run("Error_ActiveKekNotFound", func(t *testing.T) { + t.Run("Error_KeyringEncryptFails", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + uc, fake, repo := newSecretUseCase(t, 1024) + boom := errors.New("kms unavailable") + fake.FailEncrypt = boom - // Create empty KEK chain (no KEKs available) - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() + repo.EXPECT(). + GetByPath(ctx, "p"). + Return(nil, secretsDomain.ErrSecretNotFound) - path := "app/api-key" - value := []byte("secret-value") - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.Error(t, err) - assert.Nil(t, secret) - assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) + _, err := uc.CreateOrUpdate(ctx, "p", []byte("x")) + assert.ErrorIs(t, err, boom) }) t.Run("Error_SecretRepoGetByPathFails", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, _, repo := newSecretUseCase(t, 1024) + boom := errors.New("db down") + repo.EXPECT().GetByPath(ctx, "p").Return(nil, boom) - path := "app/api-key" - value := []byte("secret-value") - expectedError := errors.New("database error") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockSecretRepo.EXPECT(). - GetByPath(mock.Anything, path). - Return(nil, expectedError). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.Error(t, err) - assert.Nil(t, secret) - assert.Equal(t, expectedError, err) + _, err := uc.CreateOrUpdate(ctx, "p", []byte("x")) + assert.ErrorIs(t, err, boom) }) - t.Run("Error_CreateDekFails", func(t *testing.T) { + t.Run("Error_SecretRepoCreateFails", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, _, repo := newSecretUseCase(t, 1024) + boom := errors.New("insert failed") - path := "app/api-key" - value := []byte("secret-value") - expectedError := errors.New("failed to create dek") - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(mock.Anything, path). - Return(nil, secretsDomain.ErrSecretNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedError). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(cryptoDomain.Dek{}, expectedError). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.Error(t, err) - assert.Nil(t, secret) - assert.Equal(t, expectedError, err) + repo.EXPECT().GetByPath(ctx, "p").Return(nil, secretsDomain.ErrSecretNotFound) + repo.EXPECT().Create(ctx, mock.Anything).Return(boom) + + _, err := uc.CreateOrUpdate(ctx, "p", []byte("x")) + assert.ErrorIs(t, err, boom) }) t.Run("Error_SecretValueTooLarge", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + uc, _, _ := newSecretUseCase(t, 4) - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - path := "app/api-key" - value := make([]byte, 10) // 10 bytes - - // Use a limit of 5 bytes - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 5, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.Error(t, err) - assert.Nil(t, secret) - assert.True(t, errors.Is(err, secretsDomain.ErrSecretValueTooLarge)) + _, err := uc.CreateOrUpdate(ctx, "p", []byte("too-long-value")) + assert.ErrorIs(t, err, secretsDomain.ErrSecretValueTooLarge) }) t.Run("Error_InvalidPath", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - path := "/invalid/path" - value := []byte("secret-value") - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - secret, err := uc.CreateOrUpdate(ctx, path, value) - - // Assert - assert.Error(t, err) - assert.Nil(t, secret) - assert.Equal(t, "invalid secret path format: invalid input", err.Error()) + uc, _, _ := newSecretUseCase(t, 1024) + + _, err := uc.CreateOrUpdate(ctx, "/leading-slash", []byte("x")) + assert.ErrorIs(t, err, secretsDomain.ErrInvalidSecretPath) }) } -// TestSecretUseCase_Get tests the Get method of secretUseCase. +// ============================================================================= +// Get / GetByVersion +// ============================================================================= + func TestSecretUseCase_Get(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("Success_GetAndDecryptSecret", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Create test data - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, fake, repo := newSecretUseCase(t, 1024) - path := "app/api-key" - dekID := uuid.Must(uuid.NewV7()) - ciphertext := []byte("encrypted-secret") - nonce := []byte("secret-nonce") - plaintext := []byte("secret-value") + plaintext := []byte("payload") + env, err := fake.Encrypt(ctx, plaintext) + require.NoError(t, err) - secret := &secretsDomain.Secret{ + stored := &secretsDomain.Secret{ ID: uuid.Must(uuid.NewV7()), - Path: path, + Path: "p", Version: 1, - DekID: dekID, - Ciphertext: ciphertext, - Nonce: nonce, + DekID: env.DekID, + Ciphertext: env.Ciphertext, + Nonce: env.Nonce, CreatedAt: time.Now().UTC(), } + repo.EXPECT().GetByPath(ctx, "p").Return(stored, nil) - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } - - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(ctx, path). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(plaintext, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.Get(ctx, path) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, path, result.Path) - assert.Equal(t, plaintext, result.Plaintext) + got, err := uc.Get(ctx, "p") + require.NoError(t, err) + assert.Equal(t, plaintext, got.Plaintext) }) t.Run("Error_SecretNotFound", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/nonexistent" + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT().GetByPath(ctx, "missing").Return(nil, secretsDomain.ErrSecretNotFound) - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(ctx, path). - Return(nil, secretsDomain.ErrSecretNotFound). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.Get(ctx, path) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, apperrors.ErrNotFound)) + _, err := uc.Get(ctx, "missing") + assert.ErrorIs(t, err, secretsDomain.ErrSecretNotFound) }) - t.Run("Error_DekNotFound", func(t *testing.T) { + t.Run("Error_DekNotFound_Propagates", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, fake, repo := newSecretUseCase(t, 1024) + fake.FailDecrypt = cryptoDomain.ErrDekNotFound - path := "app/api-key" - dekID := uuid.Must(uuid.NewV7()) + repo.EXPECT().GetByPath(ctx, "p").Return(&secretsDomain.Secret{ + DekID: uuid.New(), + }, nil) - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 1, - DekID: dekID, - Ciphertext: []byte("encrypted-secret"), - Nonce: []byte("secret-nonce"), - CreatedAt: time.Now().UTC(), - } - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(ctx, path). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(nil, cryptoDomain.ErrDekNotFound). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.Get(ctx, path) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, cryptoDomain.ErrDekNotFound)) + _, err := uc.Get(ctx, "p") + assert.ErrorIs(t, err, cryptoDomain.ErrDekNotFound) }) - t.Run("Error_KekNotFound", func(t *testing.T) { + t.Run("Error_KekNotFound_Propagates", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - dekID := uuid.Must(uuid.NewV7()) - differentKekID := uuid.Must(uuid.NewV7()) - - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 1, - DekID: dekID, - Ciphertext: []byte("encrypted-secret"), - Nonce: []byte("secret-nonce"), - CreatedAt: time.Now().UTC(), - } + uc, fake, repo := newSecretUseCase(t, 1024) + fake.FailDecrypt = cryptoDomain.ErrKekNotFound - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: differentKekID, // Different KEK ID not in the chain - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } + repo.EXPECT().GetByPath(ctx, "p").Return(&secretsDomain.Secret{ + DekID: uuid.New(), + }, nil) - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(ctx, path). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.Get(ctx, path) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) + _, err := uc.Get(ctx, "p") + assert.ErrorIs(t, err, cryptoDomain.ErrKekNotFound) }) - t.Run("Error_DecryptionFailed", func(t *testing.T) { + t.Run("Error_DecryptionFailed_GenericErrorWraps", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Create test data - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, fake, repo := newSecretUseCase(t, 1024) + fake.FailDecrypt = errors.New("AEAD tag mismatch") - path := "app/api-key" - dekID := uuid.Must(uuid.NewV7()) - ciphertext := []byte("encrypted-secret") - nonce := []byte("secret-nonce") + repo.EXPECT().GetByPath(ctx, "p").Return(&secretsDomain.Secret{ + DekID: uuid.New(), + }, nil) - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: 1, - DekID: dekID, - Ciphertext: ciphertext, - Nonce: nonce, - CreatedAt: time.Now().UTC(), - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } - - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPath(ctx, path). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(nil, errors.New("decryption failed")). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.Get(ctx, path) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, cryptoDomain.ErrDecryptionFailed)) + _, err := uc.Get(ctx, "p") + assert.ErrorIs(t, err, cryptoDomain.ErrDecryptionFailed) }) } -// TestSecretUseCase_Delete tests the Delete method of secretUseCase. -func TestSecretUseCase_Delete(t *testing.T) { - t.Parallel() - ctx := context.Background() - - t.Run("Success_DeleteSecret", func(t *testing.T) { - t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - - // Setup expectations - mockSecretRepo.EXPECT(). - Delete(ctx, path). - Return(nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - err := uc.Delete(ctx, path) - - // Assert - assert.NoError(t, err) - }) - - t.Run("Error_SecretNotFound", func(t *testing.T) { - t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/nonexistent" - - // Setup expectations - mockSecretRepo.EXPECT(). - Delete(ctx, path). - Return(secretsDomain.ErrSecretNotFound). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - err := uc.Delete(ctx, path) - - // Assert - assert.Error(t, err) - assert.True(t, errors.Is(err, apperrors.ErrNotFound)) - }) - - t.Run("Error_DeleteFails", func(t *testing.T) { - t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - expectedError := errors.New("database error") - - // Setup expectations - mockSecretRepo.EXPECT(). - Delete(ctx, path). - Return(expectedError). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - err := uc.Delete(ctx, path) - - // Assert - assert.Error(t, err) - assert.Equal(t, expectedError, err) - }) -} - -// TestSecretUseCase_GetByVersion tests the GetByVersion method of secretUseCase. func TestSecretUseCase_GetByVersion(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("Success_GetSpecificVersion", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Create test data - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, fake, repo := newSecretUseCase(t, 1024) - path := "app/api-key" - version := uint(2) - dekID := uuid.Must(uuid.NewV7()) - ciphertext := []byte("encrypted-secret") - nonce := []byte("secret-nonce") - plaintext := []byte("secret-value") + plaintext := []byte("v7") + env, err := fake.Encrypt(ctx, plaintext) + require.NoError(t, err) - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: version, - DekID: dekID, - Ciphertext: ciphertext, - Nonce: nonce, - CreatedAt: time.Now().UTC(), + stored := &secretsDomain.Secret{ + Path: "p", + Version: 7, + DekID: env.DekID, + Ciphertext: env.Ciphertext, + Nonce: env.Nonce, } + repo.EXPECT().GetByPathAndVersion(ctx, "p", uint(7)).Return(stored, nil) - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } - - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPathAndVersion(ctx, path, version). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(plaintext, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.GetByVersion(ctx, path, version) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, path, result.Path) - assert.Equal(t, version, result.Version) - assert.Equal(t, plaintext, result.Plaintext) + got, err := uc.GetByVersion(ctx, "p", 7) + require.NoError(t, err) + assert.Equal(t, plaintext, got.Plaintext) }) t.Run("Error_SecretNotFound", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/nonexistent" - version := uint(1) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPathAndVersion(ctx, path, version). - Return(nil, secretsDomain.ErrSecretNotFound). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.GetByVersion(ctx, path, version) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, apperrors.ErrNotFound)) + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT(). + GetByPathAndVersion(ctx, "p", uint(99)). + Return(nil, secretsDomain.ErrSecretNotFound) + + _, err := uc.GetByVersion(ctx, "p", 99) + assert.ErrorIs(t, err, secretsDomain.ErrSecretNotFound) }) t.Run("Error_DecryptionFailed", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() + uc, fake, repo := newSecretUseCase(t, 1024) + fake.FailDecrypt = errors.New("AEAD tag mismatch") + repo.EXPECT(). + GetByPathAndVersion(ctx, "p", uint(1)). + Return(&secretsDomain.Secret{DekID: uuid.New()}, nil) + + _, err := uc.GetByVersion(ctx, "p", 1) + assert.ErrorIs(t, err, cryptoDomain.ErrDecryptionFailed) + }) +} - path := "app/api-key" - version := uint(1) - dekID := uuid.Must(uuid.NewV7()) - ciphertext := []byte("encrypted-secret") - nonce := []byte("secret-nonce") +// ============================================================================= +// Delete +// ============================================================================= - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: version, - DekID: dekID, - Ciphertext: ciphertext, - Nonce: nonce, - CreatedAt: time.Now().UTC(), - } +func TestSecretUseCase_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } + t.Run("Success_DeleteSecret", func(t *testing.T) { + t.Parallel() + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT().Delete(ctx, "p").Return(nil) - dekKey := make([]byte, 32) - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPathAndVersion(ctx, path, version). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(nil, errors.New("decryption failed")). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.GetByVersion(ctx, path, version) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, cryptoDomain.ErrDecryptionFailed)) + assert.NoError(t, uc.Delete(ctx, "p")) }) - t.Run("Error_DekNotFound", func(t *testing.T) { + t.Run("Error_SecretNotFound", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - version := uint(1) - dekID := uuid.Must(uuid.NewV7()) - - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: version, - DekID: dekID, - Ciphertext: []byte("encrypted-secret"), - Nonce: []byte("secret-nonce"), - CreatedAt: time.Now().UTC(), - } + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT().Delete(ctx, "p").Return(secretsDomain.ErrSecretNotFound) - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPathAndVersion(ctx, path, version). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(nil, cryptoDomain.ErrDekNotFound). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.GetByVersion(ctx, path, version) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, apperrors.ErrNotFound)) + err := uc.Delete(ctx, "p") + assert.ErrorIs(t, err, secretsDomain.ErrSecretNotFound) }) - t.Run("Error_KekNotFound", func(t *testing.T) { + t.Run("Error_DeleteFails", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekID := uuid.Must(uuid.NewV7()) - kek := &cryptoDomain.Kek{ - ID: kekID, - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - Key: make([]byte, 32), - EncryptedKey: []byte("encrypted-kek"), - Nonce: []byte("kek-nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } - kekChain := createKekChain([]*cryptoDomain.Kek{kek}) - defer kekChain.Close() - - path := "app/api-key" - version := uint(1) - dekID := uuid.Must(uuid.NewV7()) - differentKekID := uuid.Must(uuid.NewV7()) // Different KEK ID - - secret := &secretsDomain.Secret{ - ID: uuid.Must(uuid.NewV7()), - Path: path, - Version: version, - DekID: dekID, - Ciphertext: []byte("encrypted-secret"), - Nonce: []byte("secret-nonce"), - CreatedAt: time.Now().UTC(), - } + uc, _, repo := newSecretUseCase(t, 1024) + boom := errors.New("db error") + repo.EXPECT().Delete(ctx, "p").Return(boom) - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: differentKekID, // KEK not in chain - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - CreatedAt: time.Now().UTC(), - } - - // Setup expectations - mockSecretRepo.EXPECT(). - GetByPathAndVersion(ctx, path, version). - Return(secret, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - result, err := uc.GetByVersion(ctx, path, version) - - // Assert - assert.Error(t, err) - assert.Nil(t, result) - assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) + assert.ErrorIs(t, uc.Delete(ctx, "p"), boom) }) } -// TestSecretUseCase_PurgeDeleted tests the PurgeDeleted method of secretUseCase. +// ============================================================================= +// PurgeDeleted +// ============================================================================= + func TestSecretUseCase_PurgeDeleted(t *testing.T) { t.Parallel() ctx := context.Background() t.Run("Success_PurgeDeletedSecrets", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedCount := int64(5) - - // Setup expectations - mockSecretRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedCount, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedCount, count) + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT(). + HardDelete(ctx, mock.MatchedBy(func(tm time.Time) bool { + return !tm.IsZero() + }), false). + Return(int64(5), nil) + + n, err := uc.PurgeDeleted(ctx, 30, false) + require.NoError(t, err) + assert.EqualValues(t, 5, n) }) t.Run("Success_DryRun", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - olderThanDays := 60 - dryRun := true - expectedCount := int64(10) - - // Setup expectations - mockSecretRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedCount, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedCount, count) + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT(). + HardDelete(ctx, mock.Anything, true). + Return(int64(3), nil) + + n, err := uc.PurgeDeleted(ctx, 30, true) + require.NoError(t, err) + assert.EqualValues(t, 3, n) }) t.Run("Success_NoSecretsToDelete", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - olderThanDays := 90 - dryRun := false - expectedCount := int64(0) - - // Setup expectations - mockSecretRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedCount, nil). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedCount, count) + uc, _, repo := newSecretUseCase(t, 1024) + repo.EXPECT().HardDelete(ctx, mock.Anything, false).Return(int64(0), nil) + + n, err := uc.PurgeDeleted(ctx, 30, false) + require.NoError(t, err) + assert.Zero(t, n) }) t.Run("Error_NegativeDays", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - olderThanDays := -5 - dryRun := false - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert + uc, _, _ := newSecretUseCase(t, 1024) + _, err := uc.PurgeDeleted(ctx, -1, false) assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.Contains(t, err.Error(), "olderThanDays must be non-negative") }) t.Run("Error_RepositoryFails", func(t *testing.T) { t.Parallel() - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockDekRepo := secretsUsecaseMocks.NewMockDekRepository(t) - mockSecretRepo := secretsUsecaseMocks.NewMockSecretRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - kekChain := createKekChain([]*cryptoDomain.Kek{}) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedError := errors.New("database error") - - // Setup expectations - mockSecretRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(int64(0), expectedError). - Once() - - // Execute - uc := NewSecretUseCase( - mockTxManager, - mockDekRepo, - mockSecretRepo, - kekChain, - mockAEADManager, - mockKeyManager, - cryptoDomain.AESGCM, - 524288, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.Equal(t, expectedError, err) - }) -} + uc, _, repo := newSecretUseCase(t, 1024) + boom := apperrors.Wrap(apperrors.ErrInvalidInput, "boom") + repo.EXPECT().HardDelete(ctx, mock.Anything, false).Return(int64(0), boom) -// createKekChain is a helper function to create a KEK chain for testing. -func createKekChain(keks []*cryptoDomain.Kek) *cryptoDomain.KekChain { - return cryptoDomain.NewKekChain(keks) + _, err := uc.PurgeDeleted(ctx, 30, false) + assert.ErrorIs(t, err, boom) + }) } diff --git a/internal/tokenization/usecase/helpers.go b/internal/tokenization/usecase/helpers.go deleted file mode 100644 index 0a7dd07..0000000 --- a/internal/tokenization/usecase/helpers.go +++ /dev/null @@ -1,17 +0,0 @@ -package usecase - -import ( - "github.com/google/uuid" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" -) - -// getKek retrieves a KEK from the chain by its ID. -// Returns ErrKekNotFound if the KEK is not in the chain. -func getKek(kekChain *cryptoDomain.KekChain, kekID uuid.UUID) (*cryptoDomain.Kek, error) { - kek, ok := kekChain.Get(kekID) - if !ok { - return nil, cryptoDomain.ErrKekNotFound - } - return kek, nil -} diff --git a/internal/tokenization/usecase/helpers_test.go b/internal/tokenization/usecase/helpers_test.go deleted file mode 100644 index a7bd595..0000000 --- a/internal/tokenization/usecase/helpers_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package usecase - -import ( - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - tokenizationTesting "github.com/allisson/secrets/internal/tokenization/testing" -) - -func TestGetKek(t *testing.T) { - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - - t.Run("Success_GetActiveKek", func(t *testing.T) { - kek, err := getKek(kekChain, activeKek.ID) - assert.NoError(t, err) - assert.NotNil(t, kek) - assert.Equal(t, activeKek.ID, kek.ID) - }) - - t.Run("Error_KekNotFound", func(t *testing.T) { - randomID := uuid.Must(uuid.NewV7()) - kek, err := getKek(kekChain, randomID) - assert.ErrorIs(t, err, cryptoDomain.ErrKekNotFound) - assert.Nil(t, kek) - }) -} diff --git a/internal/tokenization/usecase/interface.go b/internal/tokenization/usecase/interface.go index 11f852f..6a259ec 100644 --- a/internal/tokenization/usecase/interface.go +++ b/internal/tokenization/usecase/interface.go @@ -12,12 +12,6 @@ import ( tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) -// DekRepository defines the interface for DEK persistence operations. -type DekRepository interface { - Create(ctx context.Context, dek *cryptoDomain.Dek) error - Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) -} - // TokenizationKeyRepository defines the interface for tokenization key persistence. type TokenizationKeyRepository interface { Create(ctx context.Context, key *tokenizationDomain.TokenizationKey) error diff --git a/internal/tokenization/usecase/mocks/mocks.go b/internal/tokenization/usecase/mocks/mocks.go index 24ef2c1..fafbcba 100644 --- a/internal/tokenization/usecase/mocks/mocks.go +++ b/internal/tokenization/usecase/mocks/mocks.go @@ -8,8 +8,8 @@ import ( "context" "time" - "github.com/allisson/secrets/internal/crypto/domain" - domain0 "github.com/allisson/secrets/internal/tokenization/domain" + domain0 "github.com/allisson/secrets/internal/crypto/domain" + "github.com/allisson/secrets/internal/tokenization/domain" "github.com/google/uuid" mock "github.com/stretchr/testify/mock" ) @@ -98,158 +98,6 @@ func (_c *MockHashService_Hash_Call) RunAndReturn(run func(value []byte, salt [] return _c } -// NewMockDekRepository creates a new instance of MockDekRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockDekRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *MockDekRepository { - mock := &MockDekRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockDekRepository is an autogenerated mock type for the DekRepository type -type MockDekRepository struct { - mock.Mock -} - -type MockDekRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockDekRepository) EXPECT() *MockDekRepository_Expecter { - return &MockDekRepository_Expecter{mock: &_m.Mock} -} - -// Create provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Create(ctx context.Context, dek *domain.Dek) error { - ret := _mock.Called(ctx, dek) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Dek) error); ok { - r0 = returnFunc(ctx, dek) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockDekRepository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type MockDekRepository_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - dek *domain.Dek -func (_e *MockDekRepository_Expecter) Create(ctx interface{}, dek interface{}) *MockDekRepository_Create_Call { - return &MockDekRepository_Create_Call{Call: _e.mock.On("Create", ctx, dek)} -} - -func (_c *MockDekRepository_Create_Call) Run(run func(ctx context.Context, dek *domain.Dek)) *MockDekRepository_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *domain.Dek - if args[1] != nil { - arg1 = args[1].(*domain.Dek) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Create_Call) Return(err error) *MockDekRepository_Create_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockDekRepository_Create_Call) RunAndReturn(run func(ctx context.Context, dek *domain.Dek) error) *MockDekRepository_Create_Call { - _c.Call.Return(run) - return _c -} - -// Get provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Get(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error) { - ret := _mock.Called(ctx, dekID) - - if len(ret) == 0 { - panic("no return value specified for Get") - } - - var r0 *domain.Dek - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*domain.Dek, error)); ok { - return returnFunc(ctx, dekID) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) *domain.Dek); ok { - r0 = returnFunc(ctx, dekID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain.Dek) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { - r1 = returnFunc(ctx, dekID) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockDekRepository_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' -type MockDekRepository_Get_Call struct { - *mock.Call -} - -// Get is a helper method to define mock.On call -// - ctx context.Context -// - dekID uuid.UUID -func (_e *MockDekRepository_Expecter) Get(ctx interface{}, dekID interface{}) *MockDekRepository_Get_Call { - return &MockDekRepository_Get_Call{Call: _e.mock.On("Get", ctx, dekID)} -} - -func (_c *MockDekRepository_Get_Call) Run(run func(ctx context.Context, dekID uuid.UUID)) *MockDekRepository_Get_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 uuid.UUID - if args[1] != nil { - arg1 = args[1].(uuid.UUID) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Get_Call) Return(dek *domain.Dek, err error) *MockDekRepository_Get_Call { - _c.Call.Return(dek, err) - return _c -} - -func (_c *MockDekRepository_Get_Call) RunAndReturn(run func(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error)) *MockDekRepository_Get_Call { - _c.Call.Return(run) - return _c -} - // NewMockTokenizationKeyRepository creates a new instance of MockTokenizationKeyRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockTokenizationKeyRepository(t interface { @@ -278,7 +126,7 @@ func (_m *MockTokenizationKeyRepository) EXPECT() *MockTokenizationKeyRepository } // Create provides a mock function for the type MockTokenizationKeyRepository -func (_mock *MockTokenizationKeyRepository) Create(ctx context.Context, key *domain0.TokenizationKey) error { +func (_mock *MockTokenizationKeyRepository) Create(ctx context.Context, key *domain.TokenizationKey) error { ret := _mock.Called(ctx, key) if len(ret) == 0 { @@ -286,7 +134,7 @@ func (_mock *MockTokenizationKeyRepository) Create(ctx context.Context, key *dom } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain0.TokenizationKey) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.TokenizationKey) error); ok { r0 = returnFunc(ctx, key) } else { r0 = ret.Error(0) @@ -301,20 +149,20 @@ type MockTokenizationKeyRepository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - key *domain0.TokenizationKey +// - key *domain.TokenizationKey func (_e *MockTokenizationKeyRepository_Expecter) Create(ctx interface{}, key interface{}) *MockTokenizationKeyRepository_Create_Call { return &MockTokenizationKeyRepository_Create_Call{Call: _e.mock.On("Create", ctx, key)} } -func (_c *MockTokenizationKeyRepository_Create_Call) Run(run func(ctx context.Context, key *domain0.TokenizationKey)) *MockTokenizationKeyRepository_Create_Call { +func (_c *MockTokenizationKeyRepository_Create_Call) Run(run func(ctx context.Context, key *domain.TokenizationKey)) *MockTokenizationKeyRepository_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 *domain0.TokenizationKey + var arg1 *domain.TokenizationKey if args[1] != nil { - arg1 = args[1].(*domain0.TokenizationKey) + arg1 = args[1].(*domain.TokenizationKey) } run( arg0, @@ -329,7 +177,7 @@ func (_c *MockTokenizationKeyRepository_Create_Call) Return(err error) *MockToke return _c } -func (_c *MockTokenizationKeyRepository_Create_Call) RunAndReturn(run func(ctx context.Context, key *domain0.TokenizationKey) error) *MockTokenizationKeyRepository_Create_Call { +func (_c *MockTokenizationKeyRepository_Create_Call) RunAndReturn(run func(ctx context.Context, key *domain.TokenizationKey) error) *MockTokenizationKeyRepository_Create_Call { _c.Call.Return(run) return _c } @@ -392,23 +240,23 @@ func (_c *MockTokenizationKeyRepository_Delete_Call) RunAndReturn(run func(ctx c } // Get provides a mock function for the type MockTokenizationKeyRepository -func (_mock *MockTokenizationKeyRepository) Get(ctx context.Context, keyID uuid.UUID) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyRepository) Get(ctx context.Context, keyID uuid.UUID) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, keyID) if len(ret) == 0 { panic("no return value specified for Get") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, keyID) } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, keyID) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { @@ -449,34 +297,34 @@ func (_c *MockTokenizationKeyRepository_Get_Call) Run(run func(ctx context.Conte return _c } -func (_c *MockTokenizationKeyRepository_Get_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyRepository_Get_Call { +func (_c *MockTokenizationKeyRepository_Get_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyRepository_Get_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyRepository_Get_Call) RunAndReturn(run func(ctx context.Context, keyID uuid.UUID) (*domain0.TokenizationKey, error)) *MockTokenizationKeyRepository_Get_Call { +func (_c *MockTokenizationKeyRepository_Get_Call) RunAndReturn(run func(ctx context.Context, keyID uuid.UUID) (*domain.TokenizationKey, error)) *MockTokenizationKeyRepository_Get_Call { _c.Call.Return(run) return _c } // GetByName provides a mock function for the type MockTokenizationKeyRepository -func (_mock *MockTokenizationKeyRepository) GetByName(ctx context.Context, name string) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyRepository) GetByName(ctx context.Context, name string) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for GetByName") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, name) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, name) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -517,34 +365,34 @@ func (_c *MockTokenizationKeyRepository_GetByName_Call) Run(run func(ctx context return _c } -func (_c *MockTokenizationKeyRepository_GetByName_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyRepository_GetByName_Call { +func (_c *MockTokenizationKeyRepository_GetByName_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyRepository_GetByName_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyRepository_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain0.TokenizationKey, error)) *MockTokenizationKeyRepository_GetByName_Call { +func (_c *MockTokenizationKeyRepository_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain.TokenizationKey, error)) *MockTokenizationKeyRepository_GetByName_Call { _c.Call.Return(run) return _c } // GetByNameAndVersion provides a mock function for the type MockTokenizationKeyRepository -func (_mock *MockTokenizationKeyRepository) GetByNameAndVersion(ctx context.Context, name string, version uint) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyRepository) GetByNameAndVersion(ctx context.Context, name string, version uint) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, name, version) if len(ret) == 0 { panic("no return value specified for GetByNameAndVersion") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, name, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, name, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) error); ok { @@ -591,12 +439,12 @@ func (_c *MockTokenizationKeyRepository_GetByNameAndVersion_Call) Run(run func(c return _c } -func (_c *MockTokenizationKeyRepository_GetByNameAndVersion_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyRepository_GetByNameAndVersion_Call { +func (_c *MockTokenizationKeyRepository_GetByNameAndVersion_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyRepository_GetByNameAndVersion_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TokenizationKey, error)) *MockTokenizationKeyRepository_GetByNameAndVersion_Call { +func (_c *MockTokenizationKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain.TokenizationKey, error)) *MockTokenizationKeyRepository_GetByNameAndVersion_Call { _c.Call.Return(run) return _c } @@ -674,23 +522,23 @@ func (_c *MockTokenizationKeyRepository_HardDelete_Call) RunAndReturn(run func(c } // ListCursor provides a mock function for the type MockTokenizationKeyRepository -func (_mock *MockTokenizationKeyRepository) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyRepository) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain.TokenizationKey, error) { ret := _mock.Called(ctx, afterName, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.TokenizationKey + var r0 []*domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.TokenizationKey, error)); ok { return returnFunc(ctx, afterName, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.TokenizationKey); ok { r0 = returnFunc(ctx, afterName, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.TokenizationKey) + r0 = ret.Get(0).([]*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -737,12 +585,12 @@ func (_c *MockTokenizationKeyRepository_ListCursor_Call) Run(run func(ctx contex return _c } -func (_c *MockTokenizationKeyRepository_ListCursor_Call) Return(tokenizationKeys []*domain0.TokenizationKey, err error) *MockTokenizationKeyRepository_ListCursor_Call { +func (_c *MockTokenizationKeyRepository_ListCursor_Call) Return(tokenizationKeys []*domain.TokenizationKey, err error) *MockTokenizationKeyRepository_ListCursor_Call { _c.Call.Return(tokenizationKeys, err) return _c } -func (_c *MockTokenizationKeyRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain0.TokenizationKey, error)) *MockTokenizationKeyRepository_ListCursor_Call { +func (_c *MockTokenizationKeyRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain.TokenizationKey, error)) *MockTokenizationKeyRepository_ListCursor_Call { _c.Call.Return(run) return _c } @@ -841,7 +689,7 @@ func (_c *MockTokenRepository_CountExpired_Call) RunAndReturn(run func(ctx conte } // Create provides a mock function for the type MockTokenRepository -func (_mock *MockTokenRepository) Create(ctx context.Context, token *domain0.Token) error { +func (_mock *MockTokenRepository) Create(ctx context.Context, token *domain.Token) error { ret := _mock.Called(ctx, token) if len(ret) == 0 { @@ -849,7 +697,7 @@ func (_mock *MockTokenRepository) Create(ctx context.Context, token *domain0.Tok } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain0.Token) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Token) error); ok { r0 = returnFunc(ctx, token) } else { r0 = ret.Error(0) @@ -864,20 +712,20 @@ type MockTokenRepository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - token *domain0.Token +// - token *domain.Token func (_e *MockTokenRepository_Expecter) Create(ctx interface{}, token interface{}) *MockTokenRepository_Create_Call { return &MockTokenRepository_Create_Call{Call: _e.mock.On("Create", ctx, token)} } -func (_c *MockTokenRepository_Create_Call) Run(run func(ctx context.Context, token *domain0.Token)) *MockTokenRepository_Create_Call { +func (_c *MockTokenRepository_Create_Call) Run(run func(ctx context.Context, token *domain.Token)) *MockTokenRepository_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 *domain0.Token + var arg1 *domain.Token if args[1] != nil { - arg1 = args[1].(*domain0.Token) + arg1 = args[1].(*domain.Token) } run( arg0, @@ -892,13 +740,13 @@ func (_c *MockTokenRepository_Create_Call) Return(err error) *MockTokenRepositor return _c } -func (_c *MockTokenRepository_Create_Call) RunAndReturn(run func(ctx context.Context, token *domain0.Token) error) *MockTokenRepository_Create_Call { +func (_c *MockTokenRepository_Create_Call) RunAndReturn(run func(ctx context.Context, token *domain.Token) error) *MockTokenRepository_Create_Call { _c.Call.Return(run) return _c } // CreateBatch provides a mock function for the type MockTokenRepository -func (_mock *MockTokenRepository) CreateBatch(ctx context.Context, tokens []*domain0.Token) error { +func (_mock *MockTokenRepository) CreateBatch(ctx context.Context, tokens []*domain.Token) error { ret := _mock.Called(ctx, tokens) if len(ret) == 0 { @@ -906,7 +754,7 @@ func (_mock *MockTokenRepository) CreateBatch(ctx context.Context, tokens []*dom } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []*domain0.Token) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, []*domain.Token) error); ok { r0 = returnFunc(ctx, tokens) } else { r0 = ret.Error(0) @@ -921,20 +769,20 @@ type MockTokenRepository_CreateBatch_Call struct { // CreateBatch is a helper method to define mock.On call // - ctx context.Context -// - tokens []*domain0.Token +// - tokens []*domain.Token func (_e *MockTokenRepository_Expecter) CreateBatch(ctx interface{}, tokens interface{}) *MockTokenRepository_CreateBatch_Call { return &MockTokenRepository_CreateBatch_Call{Call: _e.mock.On("CreateBatch", ctx, tokens)} } -func (_c *MockTokenRepository_CreateBatch_Call) Run(run func(ctx context.Context, tokens []*domain0.Token)) *MockTokenRepository_CreateBatch_Call { +func (_c *MockTokenRepository_CreateBatch_Call) Run(run func(ctx context.Context, tokens []*domain.Token)) *MockTokenRepository_CreateBatch_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 []*domain0.Token + var arg1 []*domain.Token if args[1] != nil { - arg1 = args[1].([]*domain0.Token) + arg1 = args[1].([]*domain.Token) } run( arg0, @@ -949,7 +797,7 @@ func (_c *MockTokenRepository_CreateBatch_Call) Return(err error) *MockTokenRepo return _c } -func (_c *MockTokenRepository_CreateBatch_Call) RunAndReturn(run func(ctx context.Context, tokens []*domain0.Token) error) *MockTokenRepository_CreateBatch_Call { +func (_c *MockTokenRepository_CreateBatch_Call) RunAndReturn(run func(ctx context.Context, tokens []*domain.Token) error) *MockTokenRepository_CreateBatch_Call { _c.Call.Return(run) return _c } @@ -1021,23 +869,23 @@ func (_c *MockTokenRepository_DeleteExpired_Call) RunAndReturn(run func(ctx cont } // GetBatchByTokens provides a mock function for the type MockTokenRepository -func (_mock *MockTokenRepository) GetBatchByTokens(ctx context.Context, tokens []string) ([]*domain0.Token, error) { +func (_mock *MockTokenRepository) GetBatchByTokens(ctx context.Context, tokens []string) ([]*domain.Token, error) { ret := _mock.Called(ctx, tokens) if len(ret) == 0 { panic("no return value specified for GetBatchByTokens") } - var r0 []*domain0.Token + var r0 []*domain.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []string) ([]*domain0.Token, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) ([]*domain.Token, error)); ok { return returnFunc(ctx, tokens) } - if returnFunc, ok := ret.Get(0).(func(context.Context, []string) []*domain0.Token); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) []*domain.Token); ok { r0 = returnFunc(ctx, tokens) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.Token) + r0 = ret.Get(0).([]*domain.Token) } } if returnFunc, ok := ret.Get(1).(func(context.Context, []string) error); ok { @@ -1078,34 +926,34 @@ func (_c *MockTokenRepository_GetBatchByTokens_Call) Run(run func(ctx context.Co return _c } -func (_c *MockTokenRepository_GetBatchByTokens_Call) Return(tokens1 []*domain0.Token, err error) *MockTokenRepository_GetBatchByTokens_Call { +func (_c *MockTokenRepository_GetBatchByTokens_Call) Return(tokens1 []*domain.Token, err error) *MockTokenRepository_GetBatchByTokens_Call { _c.Call.Return(tokens1, err) return _c } -func (_c *MockTokenRepository_GetBatchByTokens_Call) RunAndReturn(run func(ctx context.Context, tokens []string) ([]*domain0.Token, error)) *MockTokenRepository_GetBatchByTokens_Call { +func (_c *MockTokenRepository_GetBatchByTokens_Call) RunAndReturn(run func(ctx context.Context, tokens []string) ([]*domain.Token, error)) *MockTokenRepository_GetBatchByTokens_Call { _c.Call.Return(run) return _c } // GetByToken provides a mock function for the type MockTokenRepository -func (_mock *MockTokenRepository) GetByToken(ctx context.Context, token string) (*domain0.Token, error) { +func (_mock *MockTokenRepository) GetByToken(ctx context.Context, token string) (*domain.Token, error) { ret := _mock.Called(ctx, token) if len(ret) == 0 { panic("no return value specified for GetByToken") } - var r0 *domain0.Token + var r0 *domain.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.Token, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.Token, error)); ok { return returnFunc(ctx, token) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.Token); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.Token); ok { r0 = returnFunc(ctx, token) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Token) + r0 = ret.Get(0).(*domain.Token) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -1146,34 +994,34 @@ func (_c *MockTokenRepository_GetByToken_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockTokenRepository_GetByToken_Call) Return(token1 *domain0.Token, err error) *MockTokenRepository_GetByToken_Call { +func (_c *MockTokenRepository_GetByToken_Call) Return(token1 *domain.Token, err error) *MockTokenRepository_GetByToken_Call { _c.Call.Return(token1, err) return _c } -func (_c *MockTokenRepository_GetByToken_Call) RunAndReturn(run func(ctx context.Context, token string) (*domain0.Token, error)) *MockTokenRepository_GetByToken_Call { +func (_c *MockTokenRepository_GetByToken_Call) RunAndReturn(run func(ctx context.Context, token string) (*domain.Token, error)) *MockTokenRepository_GetByToken_Call { _c.Call.Return(run) return _c } // GetByValueHash provides a mock function for the type MockTokenRepository -func (_mock *MockTokenRepository) GetByValueHash(ctx context.Context, keyID uuid.UUID, valueHash string) (*domain0.Token, error) { +func (_mock *MockTokenRepository) GetByValueHash(ctx context.Context, keyID uuid.UUID, valueHash string) (*domain.Token, error) { ret := _mock.Called(ctx, keyID, valueHash) if len(ret) == 0 { panic("no return value specified for GetByValueHash") } - var r0 *domain0.Token + var r0 *domain.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, string) (*domain0.Token, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, string) (*domain.Token, error)); ok { return returnFunc(ctx, keyID, valueHash) } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, string) *domain0.Token); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID, string) *domain.Token); ok { r0 = returnFunc(ctx, keyID, valueHash) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Token) + r0 = ret.Get(0).(*domain.Token) } } if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID, string) error); ok { @@ -1220,12 +1068,12 @@ func (_c *MockTokenRepository_GetByValueHash_Call) Run(run func(ctx context.Cont return _c } -func (_c *MockTokenRepository_GetByValueHash_Call) Return(token *domain0.Token, err error) *MockTokenRepository_GetByValueHash_Call { +func (_c *MockTokenRepository_GetByValueHash_Call) Return(token *domain.Token, err error) *MockTokenRepository_GetByValueHash_Call { _c.Call.Return(token, err) return _c } -func (_c *MockTokenRepository_GetByValueHash_Call) RunAndReturn(run func(ctx context.Context, keyID uuid.UUID, valueHash string) (*domain0.Token, error)) *MockTokenRepository_GetByValueHash_Call { +func (_c *MockTokenRepository_GetByValueHash_Call) RunAndReturn(run func(ctx context.Context, keyID uuid.UUID, valueHash string) (*domain.Token, error)) *MockTokenRepository_GetByValueHash_Call { _c.Call.Return(run) return _c } @@ -1315,26 +1163,26 @@ func (_m *MockTokenizationKeyUseCase) EXPECT() *MockTokenizationKeyUseCase_Expec } // Create provides a mock function for the type MockTokenizationKeyUseCase -func (_mock *MockTokenizationKeyUseCase) Create(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyUseCase) Create(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, name, formatType, isDeterministic, alg) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, name, formatType, isDeterministic, alg) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, name, formatType, isDeterministic, alg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) error); ok { r1 = returnFunc(ctx, name, formatType, isDeterministic, alg) } else { r1 = ret.Error(1) @@ -1350,14 +1198,14 @@ type MockTokenizationKeyUseCase_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context // - name string -// - formatType domain0.FormatType +// - formatType domain.FormatType // - isDeterministic bool -// - alg domain.Algorithm +// - alg domain0.Algorithm func (_e *MockTokenizationKeyUseCase_Expecter) Create(ctx interface{}, name interface{}, formatType interface{}, isDeterministic interface{}, alg interface{}) *MockTokenizationKeyUseCase_Create_Call { return &MockTokenizationKeyUseCase_Create_Call{Call: _e.mock.On("Create", ctx, name, formatType, isDeterministic, alg)} } -func (_c *MockTokenizationKeyUseCase_Create_Call) Run(run func(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm)) *MockTokenizationKeyUseCase_Create_Call { +func (_c *MockTokenizationKeyUseCase_Create_Call) Run(run func(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm)) *MockTokenizationKeyUseCase_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -1367,17 +1215,17 @@ func (_c *MockTokenizationKeyUseCase_Create_Call) Run(run func(ctx context.Conte if args[1] != nil { arg1 = args[1].(string) } - var arg2 domain0.FormatType + var arg2 domain.FormatType if args[2] != nil { - arg2 = args[2].(domain0.FormatType) + arg2 = args[2].(domain.FormatType) } var arg3 bool if args[3] != nil { arg3 = args[3].(bool) } - var arg4 domain.Algorithm + var arg4 domain0.Algorithm if args[4] != nil { - arg4 = args[4].(domain.Algorithm) + arg4 = args[4].(domain0.Algorithm) } run( arg0, @@ -1390,12 +1238,12 @@ func (_c *MockTokenizationKeyUseCase_Create_Call) Run(run func(ctx context.Conte return _c } -func (_c *MockTokenizationKeyUseCase_Create_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyUseCase_Create_Call { +func (_c *MockTokenizationKeyUseCase_Create_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyUseCase_Create_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyUseCase_Create_Call) RunAndReturn(run func(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm) (*domain0.TokenizationKey, error)) *MockTokenizationKeyUseCase_Create_Call { +func (_c *MockTokenizationKeyUseCase_Create_Call) RunAndReturn(run func(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm) (*domain.TokenizationKey, error)) *MockTokenizationKeyUseCase_Create_Call { _c.Call.Return(run) return _c } @@ -1458,23 +1306,23 @@ func (_c *MockTokenizationKeyUseCase_Delete_Call) RunAndReturn(run func(ctx cont } // GetByName provides a mock function for the type MockTokenizationKeyUseCase -func (_mock *MockTokenizationKeyUseCase) GetByName(ctx context.Context, name string) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyUseCase) GetByName(ctx context.Context, name string) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for GetByName") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, name) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, name) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -1515,34 +1363,34 @@ func (_c *MockTokenizationKeyUseCase_GetByName_Call) Run(run func(ctx context.Co return _c } -func (_c *MockTokenizationKeyUseCase_GetByName_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyUseCase_GetByName_Call { +func (_c *MockTokenizationKeyUseCase_GetByName_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyUseCase_GetByName_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyUseCase_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain0.TokenizationKey, error)) *MockTokenizationKeyUseCase_GetByName_Call { +func (_c *MockTokenizationKeyUseCase_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain.TokenizationKey, error)) *MockTokenizationKeyUseCase_GetByName_Call { _c.Call.Return(run) return _c } // ListCursor provides a mock function for the type MockTokenizationKeyUseCase -func (_mock *MockTokenizationKeyUseCase) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyUseCase) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain.TokenizationKey, error) { ret := _mock.Called(ctx, afterName, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.TokenizationKey + var r0 []*domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.TokenizationKey, error)); ok { return returnFunc(ctx, afterName, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.TokenizationKey); ok { r0 = returnFunc(ctx, afterName, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.TokenizationKey) + r0 = ret.Get(0).([]*domain.TokenizationKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -1589,12 +1437,12 @@ func (_c *MockTokenizationKeyUseCase_ListCursor_Call) Run(run func(ctx context.C return _c } -func (_c *MockTokenizationKeyUseCase_ListCursor_Call) Return(tokenizationKeys []*domain0.TokenizationKey, err error) *MockTokenizationKeyUseCase_ListCursor_Call { +func (_c *MockTokenizationKeyUseCase_ListCursor_Call) Return(tokenizationKeys []*domain.TokenizationKey, err error) *MockTokenizationKeyUseCase_ListCursor_Call { _c.Call.Return(tokenizationKeys, err) return _c } -func (_c *MockTokenizationKeyUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain0.TokenizationKey, error)) *MockTokenizationKeyUseCase_ListCursor_Call { +func (_c *MockTokenizationKeyUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain.TokenizationKey, error)) *MockTokenizationKeyUseCase_ListCursor_Call { _c.Call.Return(run) return _c } @@ -1672,26 +1520,26 @@ func (_c *MockTokenizationKeyUseCase_PurgeDeleted_Call) RunAndReturn(run func(ct } // Rotate provides a mock function for the type MockTokenizationKeyUseCase -func (_mock *MockTokenizationKeyUseCase) Rotate(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm) (*domain0.TokenizationKey, error) { +func (_mock *MockTokenizationKeyUseCase) Rotate(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm) (*domain.TokenizationKey, error) { ret := _mock.Called(ctx, name, formatType, isDeterministic, alg) if len(ret) == 0 { panic("no return value specified for Rotate") } - var r0 *domain0.TokenizationKey + var r0 *domain.TokenizationKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) (*domain0.TokenizationKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) (*domain.TokenizationKey, error)); ok { return returnFunc(ctx, name, formatType, isDeterministic, alg) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) *domain0.TokenizationKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) *domain.TokenizationKey); ok { r0 = returnFunc(ctx, name, formatType, isDeterministic, alg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TokenizationKey) + r0 = ret.Get(0).(*domain.TokenizationKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain0.FormatType, bool, domain.Algorithm) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain.FormatType, bool, domain0.Algorithm) error); ok { r1 = returnFunc(ctx, name, formatType, isDeterministic, alg) } else { r1 = ret.Error(1) @@ -1707,14 +1555,14 @@ type MockTokenizationKeyUseCase_Rotate_Call struct { // Rotate is a helper method to define mock.On call // - ctx context.Context // - name string -// - formatType domain0.FormatType +// - formatType domain.FormatType // - isDeterministic bool -// - alg domain.Algorithm +// - alg domain0.Algorithm func (_e *MockTokenizationKeyUseCase_Expecter) Rotate(ctx interface{}, name interface{}, formatType interface{}, isDeterministic interface{}, alg interface{}) *MockTokenizationKeyUseCase_Rotate_Call { return &MockTokenizationKeyUseCase_Rotate_Call{Call: _e.mock.On("Rotate", ctx, name, formatType, isDeterministic, alg)} } -func (_c *MockTokenizationKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm)) *MockTokenizationKeyUseCase_Rotate_Call { +func (_c *MockTokenizationKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm)) *MockTokenizationKeyUseCase_Rotate_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -1724,17 +1572,17 @@ func (_c *MockTokenizationKeyUseCase_Rotate_Call) Run(run func(ctx context.Conte if args[1] != nil { arg1 = args[1].(string) } - var arg2 domain0.FormatType + var arg2 domain.FormatType if args[2] != nil { - arg2 = args[2].(domain0.FormatType) + arg2 = args[2].(domain.FormatType) } var arg3 bool if args[3] != nil { arg3 = args[3].(bool) } - var arg4 domain.Algorithm + var arg4 domain0.Algorithm if args[4] != nil { - arg4 = args[4].(domain.Algorithm) + arg4 = args[4].(domain0.Algorithm) } run( arg0, @@ -1747,12 +1595,12 @@ func (_c *MockTokenizationKeyUseCase_Rotate_Call) Run(run func(ctx context.Conte return _c } -func (_c *MockTokenizationKeyUseCase_Rotate_Call) Return(tokenizationKey *domain0.TokenizationKey, err error) *MockTokenizationKeyUseCase_Rotate_Call { +func (_c *MockTokenizationKeyUseCase_Rotate_Call) Return(tokenizationKey *domain.TokenizationKey, err error) *MockTokenizationKeyUseCase_Rotate_Call { _c.Call.Return(tokenizationKey, err) return _c } -func (_c *MockTokenizationKeyUseCase_Rotate_Call) RunAndReturn(run func(ctx context.Context, name string, formatType domain0.FormatType, isDeterministic bool, alg domain.Algorithm) (*domain0.TokenizationKey, error)) *MockTokenizationKeyUseCase_Rotate_Call { +func (_c *MockTokenizationKeyUseCase_Rotate_Call) RunAndReturn(run func(ctx context.Context, name string, formatType domain.FormatType, isDeterministic bool, alg domain0.Algorithm) (*domain.TokenizationKey, error)) *MockTokenizationKeyUseCase_Rotate_Call { _c.Call.Return(run) return _c } @@ -2066,23 +1914,23 @@ func (_c *MockTokenizationUseCase_Revoke_Call) RunAndReturn(run func(ctx context } // Tokenize provides a mock function for the type MockTokenizationUseCase -func (_mock *MockTokenizationUseCase) Tokenize(ctx context.Context, keyName string, plaintext []byte, metadata map[string]any, expiresAt *time.Time) (*domain0.Token, error) { +func (_mock *MockTokenizationUseCase) Tokenize(ctx context.Context, keyName string, plaintext []byte, metadata map[string]any, expiresAt *time.Time) (*domain.Token, error) { ret := _mock.Called(ctx, keyName, plaintext, metadata, expiresAt) if len(ret) == 0 { panic("no return value specified for Tokenize") } - var r0 *domain0.Token + var r0 *domain.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, map[string]any, *time.Time) (*domain0.Token, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, map[string]any, *time.Time) (*domain.Token, error)); ok { return returnFunc(ctx, keyName, plaintext, metadata, expiresAt) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, map[string]any, *time.Time) *domain0.Token); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, map[string]any, *time.Time) *domain.Token); ok { r0 = returnFunc(ctx, keyName, plaintext, metadata, expiresAt) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.Token) + r0 = ret.Get(0).(*domain.Token) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, []byte, map[string]any, *time.Time) error); ok { @@ -2141,34 +1989,34 @@ func (_c *MockTokenizationUseCase_Tokenize_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockTokenizationUseCase_Tokenize_Call) Return(token *domain0.Token, err error) *MockTokenizationUseCase_Tokenize_Call { +func (_c *MockTokenizationUseCase_Tokenize_Call) Return(token *domain.Token, err error) *MockTokenizationUseCase_Tokenize_Call { _c.Call.Return(token, err) return _c } -func (_c *MockTokenizationUseCase_Tokenize_Call) RunAndReturn(run func(ctx context.Context, keyName string, plaintext []byte, metadata map[string]any, expiresAt *time.Time) (*domain0.Token, error)) *MockTokenizationUseCase_Tokenize_Call { +func (_c *MockTokenizationUseCase_Tokenize_Call) RunAndReturn(run func(ctx context.Context, keyName string, plaintext []byte, metadata map[string]any, expiresAt *time.Time) (*domain.Token, error)) *MockTokenizationUseCase_Tokenize_Call { _c.Call.Return(run) return _c } // TokenizeBatch provides a mock function for the type MockTokenizationUseCase -func (_mock *MockTokenizationUseCase) TokenizeBatch(ctx context.Context, keyName string, plaintexts [][]byte, metadatas []map[string]any, expiresAt *time.Time) ([]*domain0.Token, error) { +func (_mock *MockTokenizationUseCase) TokenizeBatch(ctx context.Context, keyName string, plaintexts [][]byte, metadatas []map[string]any, expiresAt *time.Time) ([]*domain.Token, error) { ret := _mock.Called(ctx, keyName, plaintexts, metadatas, expiresAt) if len(ret) == 0 { panic("no return value specified for TokenizeBatch") } - var r0 []*domain0.Token + var r0 []*domain.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, [][]byte, []map[string]any, *time.Time) ([]*domain0.Token, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, [][]byte, []map[string]any, *time.Time) ([]*domain.Token, error)); ok { return returnFunc(ctx, keyName, plaintexts, metadatas, expiresAt) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, [][]byte, []map[string]any, *time.Time) []*domain0.Token); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, [][]byte, []map[string]any, *time.Time) []*domain.Token); ok { r0 = returnFunc(ctx, keyName, plaintexts, metadatas, expiresAt) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.Token) + r0 = ret.Get(0).([]*domain.Token) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, [][]byte, []map[string]any, *time.Time) error); ok { @@ -2227,12 +2075,12 @@ func (_c *MockTokenizationUseCase_TokenizeBatch_Call) Run(run func(ctx context.C return _c } -func (_c *MockTokenizationUseCase_TokenizeBatch_Call) Return(tokens []*domain0.Token, err error) *MockTokenizationUseCase_TokenizeBatch_Call { +func (_c *MockTokenizationUseCase_TokenizeBatch_Call) Return(tokens []*domain.Token, err error) *MockTokenizationUseCase_TokenizeBatch_Call { _c.Call.Return(tokens, err) return _c } -func (_c *MockTokenizationUseCase_TokenizeBatch_Call) RunAndReturn(run func(ctx context.Context, keyName string, plaintexts [][]byte, metadatas []map[string]any, expiresAt *time.Time) ([]*domain0.Token, error)) *MockTokenizationUseCase_TokenizeBatch_Call { +func (_c *MockTokenizationUseCase_TokenizeBatch_Call) RunAndReturn(run func(ctx context.Context, keyName string, plaintexts [][]byte, metadatas []map[string]any, expiresAt *time.Time) ([]*domain.Token, error)) *MockTokenizationUseCase_TokenizeBatch_Call { _c.Call.Return(run) return _c } diff --git a/internal/tokenization/usecase/tokenization_key_usecase.go b/internal/tokenization/usecase/tokenization_key_usecase.go index 7b75fff..50bf88f 100644 --- a/internal/tokenization/usecase/tokenization_key_usecase.go +++ b/internal/tokenization/usecase/tokenization_key_usecase.go @@ -8,9 +8,9 @@ import ( "github.com/google/uuid" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoService "github.com/allisson/secrets/internal/crypto/service" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" ) @@ -18,13 +18,11 @@ import ( type tokenizationKeyUseCase struct { txManager database.TxManager tokenizationKeyRepo TokenizationKeyRepository - dekRepo DekRepository - keyManager cryptoService.KeyManager - kekChain *cryptoDomain.KekChain + keyring keyring.Keyring } -// createTokenizationKey is a helper that creates a tokenization key within an existing transaction context. -// It does NOT create its own transaction - the caller must handle transaction management. +// createTokenizationKey is a helper that creates a tokenization key within an existing +// transaction context. It does NOT create its own transaction; the caller manages it. func (t *tokenizationKeyUseCase) createTokenizationKey( ctx context.Context, name string, @@ -33,30 +31,16 @@ func (t *tokenizationKeyUseCase) createTokenizationKey( isDeterministic bool, alg cryptoDomain.Algorithm, ) (*tokenizationDomain.TokenizationKey, error) { - // Get active KEK from chain - activeKek, err := getKek(t.kekChain, t.kekChain.ActiveKekID()) + handle, err := t.keyring.AllocateDek(ctx, alg) if err != nil { - return nil, apperrors.Wrap(err, "failed to get active KEK") + return nil, apperrors.Wrap(err, "failed to allocate DEK") } - // Create DEK encrypted with active KEK - dek, err := t.keyManager.CreateDek(activeKek, alg) - if err != nil { - return nil, apperrors.Wrap(err, "failed to create DEK") - } - - // Persist DEK to database - if err := t.dekRepo.Create(ctx, &dek); err != nil { - return nil, apperrors.Wrap(err, "failed to persist DEK") - } - - // Create tokenization key keyID, err := uuid.NewV7() if err != nil { return nil, apperrors.Wrap(err, "failed to generate UUID for tokenization key") } - // Generate salt for deterministic hashing salt := make([]byte, 32) if _, err := rand.Read(salt); err != nil { return nil, apperrors.Wrap(err, "failed to generate salt") @@ -69,16 +53,14 @@ func (t *tokenizationKeyUseCase) createTokenizationKey( FormatType: formatType, IsDeterministic: isDeterministic, Salt: salt, - DekID: dek.ID, + DekID: handle.DekID, CreatedAt: time.Now().UTC(), } - // Validate tokenization key fields if err := tokenizationKey.Validate(); err != nil { return nil, apperrors.Wrap(err, "tokenization key validation failed") } - // Persist tokenization key if err := t.tokenizationKeyRepo.Create(ctx, tokenizationKey); err != nil { return nil, apperrors.Wrap(err, "failed to persist tokenization key") } @@ -87,7 +69,6 @@ func (t *tokenizationKeyUseCase) createTokenizationKey( } // Create generates and persists a new tokenization key with version 1. -// Returns ErrTokenizationKeyAlreadyExists if a key with the same name already exists. func (t *tokenizationKeyUseCase) Create( ctx context.Context, name string, @@ -95,12 +76,10 @@ func (t *tokenizationKeyUseCase) Create( isDeterministic bool, alg cryptoDomain.Algorithm, ) (*tokenizationDomain.TokenizationKey, error) { - // Validate format type if err := formatType.Validate(); err != nil { return nil, tokenizationDomain.ErrInvalidFormatType } - // Check if tokenization key with version 1 already exists existingKey, err := t.tokenizationKeyRepo.GetByNameAndVersion(ctx, name, 1) if err != nil && !apperrors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound) { return nil, apperrors.Wrap(err, "failed to check for existing tokenization key") @@ -110,13 +89,10 @@ func (t *tokenizationKeyUseCase) Create( } var tokenizationKey *tokenizationDomain.TokenizationKey - - // Wrap DEK and tokenization key creation in a transaction err = t.txManager.WithTx(ctx, func(txCtx context.Context) error { tokenizationKey, err = t.createTokenizationKey(txCtx, name, 1, formatType, isDeterministic, alg) return err }) - if err != nil { return nil, apperrors.Wrap(err, "failed to create tokenization key") } @@ -124,8 +100,7 @@ func (t *tokenizationKeyUseCase) Create( return tokenizationKey, nil } -// Rotate creates a new version of an existing tokenization key by incrementing the version number. -// Generates a new DEK for the new version while preserving old versions for detokenization. +// Rotate creates a new version of an existing tokenization key. // If the key doesn't exist, it creates the first version. func (t *tokenizationKeyUseCase) Rotate( ctx context.Context, @@ -134,26 +109,20 @@ func (t *tokenizationKeyUseCase) Rotate( isDeterministic bool, alg cryptoDomain.Algorithm, ) (*tokenizationDomain.TokenizationKey, error) { - // Validate format type if err := formatType.Validate(); err != nil { return nil, tokenizationDomain.ErrInvalidFormatType } var newKey *tokenizationDomain.TokenizationKey - err := t.txManager.WithTx(ctx, func(txCtx context.Context) error { - // Get latest tokenization key version currentKey, err := t.tokenizationKeyRepo.GetByName(txCtx, name) if err != nil { - // If key doesn't exist, create first version if apperrors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound) { newKey, err = t.createTokenizationKey(txCtx, name, 1, formatType, isDeterministic, alg) return err } return apperrors.Wrap(err, "failed to get current tokenization key") } - - // Create new tokenization key version using helper newKey, err = t.createTokenizationKey( txCtx, name, @@ -164,7 +133,6 @@ func (t *tokenizationKeyUseCase) Rotate( ) return err }) - if err != nil { return nil, apperrors.Wrap(err, "failed to rotate tokenization key") } @@ -174,15 +142,13 @@ func (t *tokenizationKeyUseCase) Rotate( // Delete soft deletes a tokenization key and all its versions by name. func (t *tokenizationKeyUseCase) Delete(ctx context.Context, name string) error { - err := t.tokenizationKeyRepo.Delete(ctx, name) - if err != nil { + if err := t.tokenizationKeyRepo.Delete(ctx, name); err != nil { return apperrors.Wrap(err, "failed to delete tokenization key") } return nil } // GetByName retrieves a single tokenization key by its name. -// Returns the latest version for the key. Filters out soft-deleted keys. func (t *tokenizationKeyUseCase) GetByName( ctx context.Context, name string, @@ -197,8 +163,7 @@ func (t *tokenizationKeyUseCase) GetByName( return key, nil } -// ListCursor retrieves tokenization keys ordered by name ascending with cursor-based pagination. -// Returns the latest version for each key name. +// ListCursor retrieves tokenization keys ordered by name ascending. func (t *tokenizationKeyUseCase) ListCursor( ctx context.Context, afterName *string, @@ -212,8 +177,6 @@ func (t *tokenizationKeyUseCase) ListCursor( } // PurgeDeleted permanently removes soft-deleted tokenization keys and their associated tokens. -// Only keys deleted longer than olderThanDays ago are affected. -// If dryRun is true, returns the count of items that would be deleted without performing the operation. func (t *tokenizationKeyUseCase) PurgeDeleted( ctx context.Context, olderThanDays int, @@ -231,15 +194,11 @@ func (t *tokenizationKeyUseCase) PurgeDeleted( func NewTokenizationKeyUseCase( txManager database.TxManager, tokenizationKeyRepo TokenizationKeyRepository, - dekRepo DekRepository, - keyManager cryptoService.KeyManager, - kekChain *cryptoDomain.KekChain, + kr keyring.Keyring, ) TokenizationKeyUseCase { return &tokenizationKeyUseCase{ txManager: txManager, tokenizationKeyRepo: tokenizationKeyRepo, - dekRepo: dekRepo, - keyManager: keyManager, - kekChain: kekChain, + keyring: kr, } } diff --git a/internal/tokenization/usecase/tokenization_key_usecase_test.go b/internal/tokenization/usecase/tokenization_key_usecase_test.go index ecadc61..45dec4a 100644 --- a/internal/tokenization/usecase/tokenization_key_usecase_test.go +++ b/internal/tokenization/usecase/tokenization_key_usecase_test.go @@ -1,875 +1,192 @@ -package usecase +package usecase_test import ( "context" - "errors" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" - databaseMocks "github.com/allisson/secrets/internal/database/mocks" + apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" - tokenizationTesting "github.com/allisson/secrets/internal/tokenization/testing" - tokenizationMocks "github.com/allisson/secrets/internal/tokenization/usecase/mocks" + "github.com/allisson/secrets/internal/tokenization/usecase" + "github.com/allisson/secrets/internal/tokenization/usecase/mocks" ) -// TestTokenizationKeyUseCase_Create tests the Create method. +// noopTxManager runs the function with no real transaction. +type noopTxManager struct{} + +func (noopTxManager) WithTx(ctx context.Context, fn func(ctx context.Context) error) error { + return fn(ctx) +} + +func newTokenizationKeyUseCase( + t *testing.T, +) (usecase.TokenizationKeyUseCase, *keyring.Fake, *mocks.MockTokenizationKeyRepository) { + t.Helper() + fake := keyring.NewFake() + repo := mocks.NewMockTokenizationKeyRepository(t) + uc := usecase.NewTokenizationKeyUseCase(noopTxManager{}, repo, fake) + return uc, fake, repo +} + func TestTokenizationKeyUseCase_Create(t *testing.T) { + t.Parallel() ctx := context.Background() t.Run("Success_CreateKeyWithUUIDFormat", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + + repo.EXPECT(). GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(activeKek, cryptoDomain.AESGCM). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(d *cryptoDomain.Dek) bool { - return d.ID == dek.ID && d.KekID == dek.KekID + Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound) + repo.EXPECT(). + Create(mock.Anything, mock.MatchedBy(func(k *tokenizationDomain.TokenizationKey) bool { + return k.Name == "test-key" && + k.FormatType == tokenizationDomain.FormatUUID && + k.Version == 1 && + !k.IsDeterministic && + len(k.Salt) == 32 && + k.DekID != [16]byte{} })). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { - return key.Name == "test-key" && - key.FormatType == tokenizationDomain.FormatUUID && - key.Version == 1 && - key.IsDeterministic == false && - key.DekID == dek.ID && - len(key.Salt) == 32 - })). - Return(nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Create(ctx, "test-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + Return(nil) - // Assert - assert.NoError(t, err) - assert.NotNil(t, key) + key, err := uc.Create(ctx, "test-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + require.NoError(t, err) assert.Equal(t, "test-key", key.Name) - assert.Equal(t, tokenizationDomain.FormatUUID, key.FormatType) assert.Equal(t, uint(1), key.Version) - assert.False(t, key.IsDeterministic) - }) - - t.Run("Success_CreateKeyWithLuhnPreservingDeterministic", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.ChaCha20, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByNameAndVersion(ctx, "payment-cards", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(activeKek, cryptoDomain.ChaCha20). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { - return key.Name == "payment-cards" && - key.FormatType == tokenizationDomain.FormatLuhnPreserving && - key.Version == 1 && - key.IsDeterministic == true && - len(key.Salt) == 32 - })). - Return(nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Create( - ctx, - "payment-cards", - tokenizationDomain.FormatLuhnPreserving, - true, - cryptoDomain.ChaCha20, - ) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, "payment-cards", key.Name) - assert.Equal(t, tokenizationDomain.FormatLuhnPreserving, key.FormatType) - assert.True(t, key.IsDeterministic) }) - t.Run("Error_KeyManagerCreateDekFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) + t.Run("Error_KeyAlreadyExists", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + repo.EXPECT(). + GetByNameAndVersion(ctx, "dup", uint(1)). + Return(&tokenizationDomain.TokenizationKey{}, nil) - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expectedError := errors.New("key manager error") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(expectedError). - Once() - - mockKeyManager.EXPECT(). - CreateDek(mock.Anything, mock.Anything). - Return(cryptoDomain.Dek{}, expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Create(ctx, "test-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, key) - assert.True(t, errors.Is(err, expectedError)) - assert.Contains(t, err.Error(), "failed to create tokenization key") + _, err := uc.Create(ctx, "dup", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + assert.ErrorIs(t, err, tokenizationDomain.ErrTokenizationKeyAlreadyExists) }) - t.Run("Error_DekRepositoryCreateFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - expectedError := errors.New("database error") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(expectedError). - Once() - - mockKeyManager.EXPECT(). - CreateDek(mock.Anything, mock.Anything). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Create(ctx, "test-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, key) - assert.True(t, errors.Is(err, expectedError)) - assert.Contains(t, err.Error(), "failed to create tokenization key") + t.Run("Error_InvalidFormatType", func(t *testing.T) { + t.Parallel() + uc, _, _ := newTokenizationKeyUseCase(t) + _, err := uc.Create(ctx, "k", tokenizationDomain.FormatType("nope"), false, cryptoDomain.AESGCM) + assert.ErrorIs(t, err, tokenizationDomain.ErrInvalidFormatType) }) - t.Run("Error_TokenizationKeyRepositoryCreateFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - expectedError := errors.New("key already exists") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(expectedError). - Once() - - mockKeyManager.EXPECT(). - CreateDek(mock.Anything, mock.Anything). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Create(ctx, "test-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + t.Run("Error_KeyringAllocateFails", func(t *testing.T) { + t.Parallel() + uc, fake, repo := newTokenizationKeyUseCase(t) + fake.FailAllocate = apperrors.Wrap(apperrors.ErrInvalidInput, "kms down") - // Assert + repo.EXPECT(). + GetByNameAndVersion(ctx, "k", uint(1)). + Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound) + + _, err := uc.Create(ctx, "k", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) assert.Error(t, err) - assert.Nil(t, key) - assert.True(t, errors.Is(err, expectedError)) - assert.Contains(t, err.Error(), "failed to create tokenization key") }) } -// TestTokenizationKeyUseCase_Rotate tests the Rotate method. func TestTokenizationKeyUseCase_Rotate(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_RotateKey", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - existingKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), - Name: "test-key", - FormatType: tokenizationDomain.FormatNumeric, - Version: 1, - IsDeterministic: true, - DekID: uuid.Must(uuid.NewV7()), - } - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(existingKey, nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(activeKek, cryptoDomain.AESGCM). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { - return key.Name == "test-key" && - key.FormatType == tokenizationDomain.FormatNumeric && - key.Version == 2 && // Version incremented - key.IsDeterministic == true && - key.DekID == dek.ID && - len(key.Salt) == 32 + t.Run("Success_IncrementsVersion", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + + repo.EXPECT(). + GetByName(ctx, "k"). + Return(&tokenizationDomain.TokenizationKey{ + Name: "k", + Version: 2, + }, nil) + repo.EXPECT(). + Create(mock.Anything, mock.MatchedBy(func(k *tokenizationDomain.TokenizationKey) bool { + return k.Version == 3 })). - Return(nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Rotate(ctx, "test-key", tokenizationDomain.FormatNumeric, true, cryptoDomain.AESGCM) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, "test-key", key.Name) - assert.Equal(t, uint(2), key.Version) + Return(nil) + + key, err := uc.Rotate(ctx, "k", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + require.NoError(t, err) + assert.Equal(t, uint(3), key.Version) }) - t.Run("Success_CreateFirstKeyWhenNoneExist", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dek := cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - // Execute the transaction function - _ = fn(ctx) - }). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - GetByName(mock.Anything, "new-key"). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - // Expectations for createTokenizationKey() call within transaction - mockKeyManager.EXPECT(). - CreateDek(activeKek, cryptoDomain.AESGCM). - Return(dek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(key *tokenizationDomain.TokenizationKey) bool { - return key.Name == "new-key" && - key.FormatType == tokenizationDomain.FormatUUID && - key.Version == 1 && - len(key.Salt) == 32 + t.Run("Success_CreatesFirstVersionWhenAbsent", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + + repo.EXPECT().GetByName(ctx, "new").Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound) + repo.EXPECT(). + Create(mock.Anything, mock.MatchedBy(func(k *tokenizationDomain.TokenizationKey) bool { + return k.Version == 1 })). - Return(nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.Rotate(ctx, "new-key", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, "new-key", key.Name) + Return(nil) + + key, err := uc.Rotate(ctx, "new", tokenizationDomain.FormatUUID, false, cryptoDomain.AESGCM) + require.NoError(t, err) assert.Equal(t, uint(1), key.Version) }) } -// TestTokenizationKeyUseCase_Delete tests the Delete method. func TestTokenizationKeyUseCase_Delete(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_DeleteKey", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - keyName := "test-key" - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - Delete(ctx, keyName). - Return(nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - err := uc.Delete(ctx, keyName) - - // Assert - assert.NoError(t, err) - }) - - t.Run("Error_DeleteFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - keyName := "test-key" - expectedError := errors.New("database error") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - Delete(ctx, keyName). - Return(expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - err := uc.Delete(ctx, keyName) - - // Assert - assert.Error(t, err) - assert.True(t, errors.Is(err, expectedError)) - assert.Contains(t, err.Error(), "failed to delete tokenization key") - }) + uc, _, repo := newTokenizationKeyUseCase(t) + repo.EXPECT().Delete(ctx, "k").Return(nil) + assert.NoError(t, uc.Delete(ctx, "k")) } -// TestTokenizationKeyUseCase_PurgeDeleted tests the PurgeDeleted method. -func TestTokenizationKeyUseCase_PurgeDeleted(t *testing.T) { +func TestTokenizationKeyUseCase_GetByName(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_PurgeDeletedKeys", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedDeletedCount := int64(5) - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedDeletedCount, nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedDeletedCount, count) - }) + t.Run("Success", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + want := &tokenizationDomain.TokenizationKey{Name: "k", Version: 1} + repo.EXPECT().GetByName(ctx, "k").Return(want, nil) - t.Run("Success_PurgeDeletedKeys_DryRun", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := true - expectedDeletedCount := int64(10) - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedDeletedCount, nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedDeletedCount, count) + got, err := uc.GetByName(ctx, "k") + require.NoError(t, err) + assert.Equal(t, want, got) }) - t.Run("Error_InvalidOlderThanDays", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - olderThanDays := -1 - dryRun := false - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.Contains(t, err.Error(), "olderThanDays must be a positive number") - }) - t.Run("Error_HardDeleteFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedError := errors.New("database error") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(int64(0), expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.True(t, errors.Is(err, expectedError)) + t.Run("Error_NotFound", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + repo.EXPECT().GetByName(ctx, "k").Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound) + + _, err := uc.GetByName(ctx, "k") + assert.ErrorIs(t, err, tokenizationDomain.ErrTokenizationKeyNotFound) }) } -// TestTokenizationKeyUseCase_GetByName tests the GetByName method. -func TestTokenizationKeyUseCase_GetByName(t *testing.T) { +func TestTokenizationKeyUseCase_PurgeDeleted(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_GetByName", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expectedKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - Version: 1, - IsDeterministic: false, - DekID: uuid.Must(uuid.NewV7()), - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(expectedKey, nil). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.GetByName(ctx, "test-key") - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedKey, key) - }) - - t.Run("Error_NotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expectedError := tokenizationDomain.ErrTokenizationKeyNotFound - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "non-existent"). - Return(nil, expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.GetByName(ctx, "non-existent") - - // Assert + t.Run("Error_NegativeDays", func(t *testing.T) { + t.Parallel() + uc, _, _ := newTokenizationKeyUseCase(t) + _, err := uc.PurgeDeleted(ctx, -1, false) assert.Error(t, err) - assert.Nil(t, key) - assert.True(t, errors.Is(err, expectedError)) }) - t.Run("Error_RepositoryFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expectedError := errors.New("database error") - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(nil, expectedError). - Once() - - // Execute - uc := NewTokenizationKeyUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockDekRepo, - mockKeyManager, - kekChain, - ) - key, err := uc.GetByName(ctx, "test-key") - - // Assert - assert.Error(t, err) - assert.Nil(t, key) - assert.True(t, errors.Is(err, expectedError)) - assert.Contains(t, err.Error(), "failed to get tokenization key") + t.Run("Success_DryRun", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTokenizationKeyUseCase(t) + repo.EXPECT().HardDelete(ctx, mock.Anything, true).Return(int64(3), nil) + + n, err := uc.PurgeDeleted(ctx, 30, true) + require.NoError(t, err) + assert.EqualValues(t, 3, n) }) } diff --git a/internal/tokenization/usecase/tokenization_usecase.go b/internal/tokenization/usecase/tokenization_usecase.go index 2a8cd46..a5ab88b 100644 --- a/internal/tokenization/usecase/tokenization_usecase.go +++ b/internal/tokenization/usecase/tokenization_usecase.go @@ -1,7 +1,8 @@ // Package usecase implements tokenization business logic. // // Coordinates token generation, encryption, and lifecycle management with configurable -// deterministic behavior. Uses TxManager for transactional consistency. +// deterministic behavior. Uses TxManager for transactional consistency and Keyring for +// envelope encryption. package usecase import ( @@ -11,36 +12,28 @@ import ( "github.com/google/uuid" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoService "github.com/allisson/secrets/internal/crypto/service" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" tokenizationService "github.com/allisson/secrets/internal/tokenization/service" ) // validateTokenLength checks if the plaintext length is valid for the token format type. func validateTokenLength(formatType tokenizationDomain.FormatType, length int) error { - // UUID format ignores length parameter if formatType == tokenizationDomain.FormatUUID { return nil } - - // Luhn format requires at least 2 characters if formatType == tokenizationDomain.FormatLuhnPreserving && length < tokenizationDomain.MinLuhnTokenLength { return tokenizationDomain.ErrTokenLengthInvalid } - - // All format-preserving tokens have max length constraint if length > tokenizationDomain.MaxTokenLength { return tokenizationDomain.ErrTokenLengthInvalid } - - // Minimum length is 1 for numeric/alphanumeric if length < 1 { return tokenizationDomain.ErrTokenLengthInvalid } - return nil } @@ -49,20 +42,11 @@ type tokenizationUseCase struct { txManager database.TxManager tokenizationRepo TokenizationKeyRepository tokenRepo TokenRepository - dekRepo DekRepository - aeadManager cryptoService.AEADManager - keyManager cryptoService.KeyManager hashService HashService - kekChain *cryptoDomain.KekChain + keyring keyring.Keyring } // Tokenize generates a token for the given plaintext value using the latest version of the named key. -// In deterministic mode, returns the existing token if the value has been tokenized before. -// Metadata is optional display data (e.g., last 4 digits) stored unencrypted. -// -// Rate Limiting: Production systems should implement rate limiting on this method to prevent abuse. -// Recommended: 100 requests per minute per user/API key for standard use cases. -// Adjust based on your specific security requirements and usage patterns. func (t *tokenizationUseCase) Tokenize( ctx context.Context, keyName string, @@ -70,7 +54,6 @@ func (t *tokenizationUseCase) Tokenize( metadata map[string]any, expiresAt *time.Time, ) (*tokenizationDomain.Token, error) { - // Validate plaintext size if len(plaintext) == 0 { return nil, tokenizationDomain.ErrPlaintextEmpty } @@ -78,69 +61,35 @@ func (t *tokenizationUseCase) Tokenize( return nil, tokenizationDomain.ErrPlaintextTooLarge } - // Get latest tokenization key version tokenizationKey, err := t.tokenizationRepo.GetByName(ctx, keyName) if err != nil { return nil, apperrors.Wrap(err, "failed to get tokenization key by name") } - // In deterministic mode, check if token already exists for this value + // In deterministic mode, look up an existing token before encrypting. if tokenizationKey.IsDeterministic { valueHash := t.hashService.Hash(plaintext, tokenizationKey.Salt) existingToken, err := t.tokenRepo.GetByValueHash(ctx, tokenizationKey.ID, valueHash) if err != nil && !apperrors.Is(err, tokenizationDomain.ErrTokenNotFound) { return nil, apperrors.Wrap(err, "failed to check existing token in deterministic mode") } - if existingToken != nil { - // Return existing valid token - if existingToken.IsValid() { - return existingToken, nil - } - // Existing token is expired or revoked - proceed to create new token + if existingToken != nil && existingToken.IsValid() { + return existingToken, nil } } - // Get DEK by tokenization key's DekID - dek, err := t.dekRepo.Get(ctx, tokenizationKey.DekID) - if err != nil { - return nil, apperrors.Wrap(err, "failed to get DEK") - } - - // Get KEK for decrypting DEK - kek, err := getKek(t.kekChain, dek.KekID) - if err != nil { - return nil, apperrors.Wrap(err, "failed to get KEK") - } - - // Decrypt DEK with KEK - dekKey, err := t.keyManager.DecryptDek(dek, kek) - if err != nil { - return nil, apperrors.Wrap(err, "failed to decrypt DEK") - } - defer cryptoDomain.Zero(dekKey) - - // Create AEAD cipher with decrypted DEK - cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) - if err != nil { - return nil, apperrors.Wrap(err, "failed to create cipher") - } - - // Encrypt plaintext - ciphertext, nonce, err := cipher.Encrypt(plaintext, nil) + handle := keyring.DekHandle{DekID: tokenizationKey.DekID} + ciphertext, nonce, err := t.keyring.EncryptWith(ctx, handle, plaintext, nil) if err != nil { return nil, apperrors.Wrap(err, "failed to encrypt plaintext") } - // Generate token using appropriate generator generator, err := tokenizationService.NewTokenGenerator(tokenizationKey.FormatType) if err != nil { return nil, err } - // For format-preserving tokens, use plaintext length as hint tokenLength := len(plaintext) - - // Validate token length matches format requirements if err := validateTokenLength(tokenizationKey.FormatType, tokenLength); err != nil { return nil, err } @@ -150,11 +99,11 @@ func (t *tokenizationUseCase) Tokenize( return nil, apperrors.Wrap(err, "failed to generate token") } - // Create token record tokenID, err := uuid.NewV7() if err != nil { return nil, apperrors.Wrap(err, "failed to generate UUID for token") } + token := &tokenizationDomain.Token{ ID: tokenID, TokenizationKeyID: tokenizationKey.ID, @@ -168,32 +117,23 @@ func (t *tokenizationUseCase) Tokenize( RevokedAt: nil, } - // In deterministic mode, store value hash for lookup if tokenizationKey.IsDeterministic { valueHash := t.hashService.Hash(plaintext, tokenizationKey.Salt) token.ValueHash = &valueHash } - // Persist token if err := t.tokenRepo.Create(ctx, token); err != nil { - // In deterministic mode, handle race condition where another goroutine - // created the same token between our check and insert + // Race: another goroutine inserted the deterministic token between the + // existence check and the insert. Re-read and return that one. if tokenizationKey.IsDeterministic && apperrors.Is(err, apperrors.ErrConflict) { - // Race detected: another concurrent request inserted this token - // Query again to get the token that was inserted valueHash := t.hashService.Hash(plaintext, tokenizationKey.Salt) existingToken, queryErr := t.tokenRepo.GetByValueHash(ctx, tokenizationKey.ID, valueHash) if queryErr != nil { - // If query fails, return original create error return nil, apperrors.Wrap(err, "failed to create token") } - - // Validate that the concurrently created token is valid before returning it if !existingToken.IsValid() { return nil, apperrors.Wrap(err, "concurrently created token is invalid or expired") } - - // Return the token created by the concurrent request return existingToken, nil } return nil, apperrors.Wrap(err, "failed to create token") @@ -202,8 +142,7 @@ func (t *tokenizationUseCase) Tokenize( return token, nil } -// TokenizeBatch generates tokens for multiple plaintext values using the latest version of the named key. -// Wrapped in a transaction for atomicity. +// TokenizeBatch generates tokens for multiple plaintext values, wrapped in a transaction. func (t *tokenizationUseCase) TokenizeBatch( ctx context.Context, keyName string, @@ -233,61 +172,30 @@ func (t *tokenizationUseCase) TokenizeBatch( } // Detokenize retrieves the original plaintext value for a given token. -// Returns ErrTokenNotFound if token doesn't exist, ErrTokenExpired if expired, ErrTokenRevoked if revoked. -// Security Note: Callers MUST zero the returned plaintext after use: cryptoDomain.Zero(plaintext). +// Security: callers MUST zero the returned plaintext after use. func (t *tokenizationUseCase) Detokenize( ctx context.Context, token string, ) (plaintext []byte, metadata map[string]any, err error) { - // Get token record tokenRecord, err := t.tokenRepo.GetByToken(ctx, token) if err != nil { return nil, nil, apperrors.Wrap(err, "failed to get token") } - // Validate token is not expired if tokenRecord.IsExpired() { return nil, nil, tokenizationDomain.ErrTokenExpired } - - // Validate token is not revoked if tokenRecord.IsRevoked() { return nil, nil, tokenizationDomain.ErrTokenRevoked } - // Get tokenization key to retrieve its DekID tokenizationKey, err := t.tokenizationRepo.Get(ctx, tokenRecord.TokenizationKeyID) if err != nil { return nil, nil, apperrors.Wrap(err, "failed to get tokenization key") } - // Get DEK - dek, err := t.dekRepo.Get(ctx, tokenizationKey.DekID) - if err != nil { - return nil, nil, apperrors.Wrap(err, "failed to get DEK") - } - - // Get KEK for decrypting DEK - kek, err := getKek(t.kekChain, dek.KekID) - if err != nil { - return nil, nil, apperrors.Wrap(err, "failed to get KEK") - } - - // Decrypt DEK with KEK - dekKey, err := t.keyManager.DecryptDek(dek, kek) - if err != nil { - return nil, nil, apperrors.Wrap(err, "failed to decrypt DEK") - } - defer cryptoDomain.Zero(dekKey) - - // Create AEAD cipher with decrypted DEK - cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) - if err != nil { - return nil, nil, apperrors.Wrap(err, "failed to create cipher") - } - - // Decrypt ciphertext with nonce - plaintext, err = cipher.Decrypt(tokenRecord.Ciphertext, tokenRecord.Nonce, nil) + handle := keyring.DekHandle{DekID: tokenizationKey.DekID} + plaintext, err = t.keyring.DecryptWith(ctx, handle, tokenRecord.Ciphertext, tokenRecord.Nonce, nil) if err != nil { return nil, nil, apperrors.Wrap( cryptoDomain.ErrDecryptionFailed, @@ -298,8 +206,7 @@ func (t *tokenizationUseCase) Detokenize( return plaintext, tokenRecord.Metadata, nil } -// DetokenizeBatch retrieves original plaintext values for multiple tokens. -// Wrapped in a transaction for atomicity. +// DetokenizeBatch retrieves original plaintext values for multiple tokens, wrapped in a transaction. func (t *tokenizationUseCase) DetokenizeBatch( ctx context.Context, tokens []string, @@ -323,7 +230,6 @@ func (t *tokenizationUseCase) DetokenizeBatch( // Validate checks if a token exists and is valid (not expired or revoked). func (t *tokenizationUseCase) Validate(ctx context.Context, token string) (bool, error) { - // Get token record tokenRecord, err := t.tokenRepo.GetByToken(ctx, token) if err != nil { if apperrors.Is(err, tokenizationDomain.ErrTokenNotFound) { @@ -331,65 +237,46 @@ func (t *tokenizationUseCase) Validate(ctx context.Context, token string) (bool, } return false, apperrors.Wrap(err, "failed to validate token") } - - // Check if token is valid return tokenRecord.IsValid(), nil } // Revoke marks a token as revoked, preventing further detokenization. func (t *tokenizationUseCase) Revoke(ctx context.Context, token string) error { - // Verify token exists first - _, err := t.tokenRepo.GetByToken(ctx, token) - if err != nil { + if _, err := t.tokenRepo.GetByToken(ctx, token); err != nil { return apperrors.Wrap(err, "failed to get token for revocation") } - - // Revoke the token - err = t.tokenRepo.Revoke(ctx, token) - if err != nil { + if err := t.tokenRepo.Revoke(ctx, token); err != nil { return apperrors.Wrap(err, "failed to revoke token") } return nil } // CleanupExpired deletes tokens that expired more than the specified number of days ago. -// Returns the number of deleted tokens. Use dryRun=true to preview count without deletion. func (t *tokenizationUseCase) CleanupExpired(ctx context.Context, days int, dryRun bool) (int64, error) { if days < 0 { return 0, apperrors.New("days must be non-negative") } - // Calculate the cutoff timestamp (days ago from now in UTC) cutoff := time.Now().UTC().AddDate(0, 0, -days) - if dryRun { - // In dry run mode, count expired tokens without deleting return t.tokenRepo.CountExpired(ctx, cutoff) } - - // Delete expired tokens return t.tokenRepo.DeleteExpired(ctx, cutoff) } -// NewTokenizationUseCase creates a new TokenizationUseCase with injected dependencies. +// NewTokenizationUseCase creates a new TokenizationUseCase backed by a Keyring. func NewTokenizationUseCase( txManager database.TxManager, tokenizationRepo TokenizationKeyRepository, tokenRepo TokenRepository, - dekRepo DekRepository, - aeadManager cryptoService.AEADManager, - keyManager cryptoService.KeyManager, hashService HashService, - kekChain *cryptoDomain.KekChain, + kr keyring.Keyring, ) TokenizationUseCase { return &tokenizationUseCase{ txManager: txManager, tokenizationRepo: tokenizationRepo, tokenRepo: tokenRepo, - dekRepo: dekRepo, - aeadManager: aeadManager, - keyManager: keyManager, hashService: hashService, - kekChain: kekChain, + keyring: kr, } } diff --git a/internal/tokenization/usecase/tokenization_usecase_test.go b/internal/tokenization/usecase/tokenization_usecase_test.go index 4b8751a..425a1eb 100644 --- a/internal/tokenization/usecase/tokenization_usecase_test.go +++ b/internal/tokenization/usecase/tokenization_usecase_test.go @@ -1,1855 +1,304 @@ -package usecase +package usecase_test import ( "context" - "errors" "testing" "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoServiceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" - databaseMocks "github.com/allisson/secrets/internal/database/mocks" + apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" tokenizationDomain "github.com/allisson/secrets/internal/tokenization/domain" - tokenizationTesting "github.com/allisson/secrets/internal/tokenization/testing" - tokenizationMocks "github.com/allisson/secrets/internal/tokenization/usecase/mocks" + "github.com/allisson/secrets/internal/tokenization/usecase" + "github.com/allisson/secrets/internal/tokenization/usecase/mocks" ) -// TestTokenizationUseCase_Tokenize tests the Tokenize method. -func TestTokenizationUseCase_Tokenize(t *testing.T) { - ctx := context.Background() +func newTokenizationUseCase( + t *testing.T, +) ( + usecase.TokenizationUseCase, + *keyring.Fake, + *mocks.MockTokenizationKeyRepository, + *mocks.MockTokenRepository, + *mocks.MockHashService, +) { + t.Helper() + fake := keyring.NewFake() + keyRepo := mocks.NewMockTokenizationKeyRepository(t) + tokenRepo := mocks.NewMockTokenRepository(t) + hashSvc := mocks.NewMockHashService(t) + uc := usecase.NewTokenizationUseCase(noopTxManager{}, keyRepo, tokenRepo, hashSvc, fake) + return uc, fake, keyRepo, tokenRepo, hashSvc +} - t.Run("Success_NonDeterministicMode", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) +// allocateDekForTest seeds the keyring Fake with a DekID and returns it, +// mimicking a tokenization key previously created via uc.Create. +func allocateDekForTest(t *testing.T, fake *keyring.Fake) uuid.UUID { + t.Helper() + handle, err := fake.AllocateDek(context.Background(), keyring.AESGCM) + require.NoError(t, err) + return handle.DekID +} - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() +func TestTokenizationUseCase_Tokenize(t *testing.T) { + t.Parallel() + ctx := context.Background() - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - salt := []byte("test-salt-32-bytes-long-12345678") + t.Run("Success_NonDeterministic", func(t *testing.T) { + t.Parallel() + uc, fake, keyRepo, tokenRepo, _ := newTokenizationUseCase(t) + dekID := allocateDekForTest(t, fake) - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", + key := &tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + Name: "k", + Version: 1, FormatType: tokenizationDomain.FormatUUID, IsDeterministic: false, - Version: 1, - Salt: salt, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - dekKey := make([]byte, 32) - plaintext := []byte("test-value") - ciphertext := []byte("encrypted-value") - nonce := []byte("test-nonce") - metadata := map[string]any{"last4": "alue"} - expiresAt := time.Now().UTC().Add(24 * time.Hour) - - // Create mock cipher - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(plaintext, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - mockTokenRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(token *tokenizationDomain.Token) bool { - return token.TokenizationKeyID == tokenizationKeyID && - len(token.Token) > 0 && - token.ValueHash == nil && - string(token.Ciphertext) == string(ciphertext) && - string(token.Nonce) == string(nonce) && - token.ExpiresAt.Equal(expiresAt) - })). - Return(nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", plaintext, metadata, &expiresAt) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, token) - assert.Equal(t, tokenizationKeyID, token.TokenizationKeyID) - assert.NotEmpty(t, token.Token) - assert.Nil(t, token.ValueHash) - assert.Equal(t, ciphertext, token.Ciphertext) - assert.Equal(t, nonce, token.Nonce) - assert.Equal(t, metadata, token.Metadata) - assert.Equal(t, expiresAt, *token.ExpiresAt) - }) - - t.Run("Success_DeterministicMode_NewToken", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - salt := []byte("deterministic-salt-32-bytes-long") - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatLuhnPreserving, - IsDeterministic: true, - Version: 1, - Salt: salt, } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - dekKey := make([]byte, 32) - plaintext := []byte("4111111111111111") - valueHash := "hash-of-plaintext" - ciphertext := []byte("encrypted-value") - nonce := []byte("test-nonce") - - // Create mock cipher - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockHashService.EXPECT(). - Hash(plaintext, salt). - Return(valueHash). - Once() - - mockTokenRepo.EXPECT(). - GetByValueHash(ctx, tokenizationKeyID, valueHash). - Return(nil, tokenizationDomain.ErrTokenNotFound). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(plaintext, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - mockHashService.EXPECT(). - Hash(plaintext, salt). - Return(valueHash). - Once() - - mockTokenRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(token *tokenizationDomain.Token) bool { - return token.TokenizationKeyID == tokenizationKeyID && - len(token.Token) > 0 && - token.ValueHash != nil && - *token.ValueHash == valueHash && - string(token.Ciphertext) == string(ciphertext) && - string(token.Nonce) == string(nonce) - })). - Return(nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", plaintext, nil, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, token) - assert.Equal(t, tokenizationKeyID, token.TokenizationKeyID) - assert.NotEmpty(t, token.Token) - assert.NotNil(t, token.ValueHash) - assert.Equal(t, valueHash, *token.ValueHash) - assert.Equal(t, ciphertext, token.Ciphertext) - assert.Equal(t, nonce, token.Nonce) + keyRepo.EXPECT().GetByName(ctx, "k").Return(key, nil) + tokenRepo.EXPECT().Create(ctx, mock.MatchedBy(func(tok *tokenizationDomain.Token) bool { + return tok.TokenizationKeyID == key.ID && + len(tok.Token) > 0 && + len(tok.Ciphertext) > 0 && + tok.ValueHash == nil + })).Return(nil) + + got, err := uc.Tokenize(ctx, "k", []byte("payload"), nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, got.Token) }) - t.Run("Success_DeterministicMode_ExistingValidToken", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - tokenizationKeyID := uuid.Must(uuid.NewV7()) - salt := []byte("deterministic-salt-32-bytes-long") - plaintext := []byte("test-value") - valueHash := "hash-of-plaintext" - existingTokenValue := "existing-token-123" + t.Run("Success_Deterministic_ReturnsExistingValidToken", func(t *testing.T) { + t.Parallel() + uc, fake, keyRepo, tokenRepo, hashSvc := newTokenizationUseCase(t) + dekID := allocateDekForTest(t, fake) - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: uuid.Must(uuid.NewV7()), - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: true, + key := &tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + Name: "k", Version: 1, - Salt: salt, - } - - existingToken := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: tokenizationKeyID, - Token: existingTokenValue, - ValueHash: &valueHash, - Ciphertext: []byte("existing-ciphertext"), - Nonce: []byte("existing-nonce"), - CreatedAt: time.Now().UTC().Add(-1 * time.Hour), - ExpiresAt: nil, // No expiration - RevokedAt: nil, // Not revoked - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockHashService.EXPECT(). - Hash(plaintext, salt). - Return(valueHash). - Once() - - mockTokenRepo.EXPECT(). - GetByValueHash(ctx, tokenizationKeyID, valueHash). - Return(existingToken, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", plaintext, nil, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, token) - assert.Equal(t, existingToken.ID, token.ID) - assert.Equal(t, existingTokenValue, token.Token) - assert.Equal(t, valueHash, *token.ValueHash) - }) - - t.Run("Success_DeterministicMode_ExpiredTokenCreatesNew", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - salt := []byte("deterministic-salt-32-bytes-long") - plaintext := []byte("test-value") - valueHash := "hash-of-plaintext" - expiredTime := time.Now().UTC().Add(-1 * time.Hour) - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", FormatType: tokenizationDomain.FormatUUID, IsDeterministic: true, - Version: 1, - Salt: salt, + Salt: []byte("salt"), + DekID: dekID, } - - expiredToken := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: tokenizationKeyID, - Token: "expired-token", - ValueHash: &valueHash, - Ciphertext: []byte("old-ciphertext"), - Nonce: []byte("old-nonce"), - CreatedAt: time.Now().UTC().Add(-2 * time.Hour), - ExpiresAt: &expiredTime, // Expired - RevokedAt: nil, + existing := &tokenizationDomain.Token{ + ID: uuid.New(), + TokenizationKeyID: key.ID, + Token: "existing-token", } - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - dekKey := make([]byte, 32) - ciphertext := []byte("new-encrypted-value") - nonce := []byte("new-nonce") - - // Create mock cipher - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockHashService.EXPECT(). - Hash(plaintext, salt). - Return(valueHash). - Once() - - mockTokenRepo.EXPECT(). - GetByValueHash(ctx, tokenizationKeyID, valueHash). - Return(expiredToken, nil). - Once() + keyRepo.EXPECT().GetByName(ctx, "k").Return(key, nil) + hashSvc.EXPECT().Hash([]byte("payload"), []byte("salt")).Return("hash-value") + tokenRepo.EXPECT(). + GetByValueHash(ctx, key.ID, "hash-value"). + Return(existing, nil) - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(plaintext, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - mockHashService.EXPECT(). - Hash(plaintext, salt). - Return(valueHash). - Once() - - mockTokenRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(token *tokenizationDomain.Token) bool { - return token.TokenizationKeyID == tokenizationKeyID && - len(token.Token) > 0 && - token.ValueHash != nil && - *token.ValueHash == valueHash && - string(token.Ciphertext) == string(ciphertext) - })). - Return(nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", plaintext, nil, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, token) - assert.NotEqual(t, expiredToken.ID, token.ID) // Should be a new token - assert.NotEqual(t, "expired-token", token.Token) + got, err := uc.Tokenize(ctx, "k", []byte("payload"), nil, nil) + require.NoError(t, err) + assert.Equal(t, "existing-token", got.Token) }) - t.Run("Error_TokenizationKeyNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "nonexistent-key"). - Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "nonexistent-key", []byte("test"), nil, nil) - - // Assert - assert.Nil(t, token) - assert.Error(t, err) - assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenizationKeyNotFound)) - assert.Contains(t, err.Error(), "failed to get tokenization key by name") + t.Run("Error_PlaintextEmpty", func(t *testing.T) { + t.Parallel() + uc, _, _, _, _ := newTokenizationUseCase(t) + _, err := uc.Tokenize(ctx, "k", nil, nil, nil) + assert.ErrorIs(t, err, tokenizationDomain.ErrPlaintextEmpty) }) - t.Run("Error_DekNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - dekID := uuid.Must(uuid.NewV7()) - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(nil, cryptoDomain.ErrDekNotFound). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", []byte("test"), nil, nil) - - // Assert - assert.Nil(t, token) - assert.Error(t, err) - assert.True(t, errors.Is(err, cryptoDomain.ErrDekNotFound)) - assert.Contains(t, err.Error(), "failed to get DEK") + t.Run("Error_PlaintextTooLarge", func(t *testing.T) { + t.Parallel() + uc, _, _, _, _ := newTokenizationUseCase(t) + big := make([]byte, tokenizationDomain.MaxPlaintextSize+1) + _, err := uc.Tokenize(ctx, "k", big, nil, nil) + assert.ErrorIs(t, err, tokenizationDomain.ErrPlaintextTooLarge) }) - t.Run("Error_KekNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - dekID := uuid.Must(uuid.NewV7()) - nonexistentKekID := uuid.Must(uuid.NewV7()) // KEK not in chain - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: nonexistentKekID, // References KEK not in chain - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() + t.Run("Error_KeyNotFound", func(t *testing.T) { + t.Parallel() + uc, _, keyRepo, _, _ := newTokenizationUseCase(t) + keyRepo.EXPECT(). + GetByName(ctx, "missing"). + Return(nil, tokenizationDomain.ErrTokenizationKeyNotFound) - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", []byte("test"), nil, nil) - - // Assert - assert.Nil(t, token) - assert.Error(t, err) - assert.True(t, errors.Is(err, cryptoDomain.ErrKekNotFound)) - assert.Contains(t, err.Error(), "failed to get KEK") + _, err := uc.Tokenize(ctx, "missing", []byte("x"), nil, nil) + assert.ErrorIs(t, err, tokenizationDomain.ErrTokenizationKeyNotFound) }) - t.Run("Error_EncryptionFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: uuid.Must(uuid.NewV7()), - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - dekKey := make([]byte, 32) - plaintext := []byte("test-value") - - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - encryptionError := errors.New("encryption failed") + t.Run("Error_KeyringEncryptFails", func(t *testing.T) { + t.Parallel() + uc, fake, keyRepo, _, _ := newTokenizationUseCase(t) + dekID := allocateDekForTest(t, fake) + fake.FailEncrypt = apperrors.New("boom") - // Setup expectations - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() + keyRepo.EXPECT().GetByName(ctx, "k").Return(&tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + FormatType: tokenizationDomain.FormatUUID, + DekID: dekID, + }, nil) - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(plaintext, mock.Anything). - Return(nil, nil, encryptionError). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - token, err := uc.Tokenize(ctx, "test-key", plaintext, nil, nil) - - // Assert - assert.Nil(t, token) + _, err := uc.Tokenize(ctx, "k", []byte("payload"), nil, nil) assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to encrypt plaintext") }) } -// TestTokenizationUseCase_Detokenize tests the Detokenize method. func TestTokenizationUseCase_Detokenize(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_DetokenizeValid", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) + t.Run("Success_RoundTrip", func(t *testing.T) { + t.Parallel() + uc, fake, keyRepo, tokenRepo, _ := newTokenizationUseCase(t) - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - tokenValue := "test-token-123" - plaintext := []byte("original-value") - ciphertext := []byte("encrypted-value") - nonce := []byte("test-nonce") - metadata := map[string]any{"last4": "alue"} + // Encrypt via the fake to get matching ciphertext/nonce/dekID. + handle, err := fake.AllocateDek(ctx, keyring.AESGCM) + require.NoError(t, err) + plaintext := []byte("4111111111111111") + ciphertext, nonce, err := fake.EncryptWith(ctx, handle, plaintext, nil) + require.NoError(t, err) - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: tokenizationKeyID, - Token: tokenValue, + key := &tokenizationDomain.TokenizationKey{ + ID: uuid.New(), + Name: "cards", + DekID: handle.DekID, + } + tokenRec := &tokenizationDomain.Token{ + TokenizationKeyID: key.ID, + Token: "tok", Ciphertext: ciphertext, Nonce: nonce, - Metadata: metadata, CreatedAt: time.Now().UTC(), - ExpiresAt: nil, - RevokedAt: nil, - } - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), } - dekKey := make([]byte, 32) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Get(ctx, tokenizationKeyID). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(tokenRec, nil) + keyRepo.EXPECT().Get(ctx, key.ID).Return(key, nil) - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(plaintext, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - resultPlaintext, resultMetadata, err := uc.Detokenize(ctx, tokenValue) - - // Assert - assert.NoError(t, err) - assert.Equal(t, plaintext, resultPlaintext) - assert.Equal(t, metadata, resultMetadata) - }) - - t.Run("Error_TokenNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, "nonexistent-token"). - Return(nil, tokenizationDomain.ErrTokenNotFound). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - plaintext, metadata, err := uc.Detokenize(ctx, "nonexistent-token") - - // Assert - assert.Nil(t, plaintext) - assert.Nil(t, metadata) - assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenNotFound)) - assert.Contains(t, err.Error(), "failed to get token") + got, _, err := uc.Detokenize(ctx, "tok") + require.NoError(t, err) + assert.Equal(t, plaintext, got) }) t.Run("Error_TokenExpired", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expiredTime := time.Now().UTC().Add(-1 * time.Hour) - tokenValue := "expired-token" - - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC().Add(-2 * time.Hour), - ExpiresAt: &expiredTime, - RevokedAt: nil, - } - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - plaintext, metadata, err := uc.Detokenize(ctx, tokenValue) - - // Assert - assert.Nil(t, plaintext) - assert.Nil(t, metadata) - assert.Equal(t, tokenizationDomain.ErrTokenExpired, err) + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + past := time.Now().Add(-time.Hour) + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(&tokenizationDomain.Token{ + Token: "tok", + ExpiresAt: &past, + }, nil) + + _, _, err := uc.Detokenize(ctx, "tok") + assert.ErrorIs(t, err, tokenizationDomain.ErrTokenExpired) }) t.Run("Error_TokenRevoked", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - revokedTime := time.Now().UTC().Add(-30 * time.Minute) - tokenValue := "revoked-token" - - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC().Add(-1 * time.Hour), - ExpiresAt: nil, - RevokedAt: &revokedTime, - } - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - plaintext, metadata, err := uc.Detokenize(ctx, tokenValue) - - // Assert - assert.Nil(t, plaintext) - assert.Nil(t, metadata) - assert.Equal(t, tokenizationDomain.ErrTokenRevoked, err) + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + now := time.Now() + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(&tokenizationDomain.Token{ + Token: "tok", + RevokedAt: &now, + }, nil) + + _, _, err := uc.Detokenize(ctx, "tok") + assert.ErrorIs(t, err, tokenizationDomain.ErrTokenRevoked) }) - t.Run("Error_DecryptionFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - tokenValue := "test-token" + t.Run("Error_DecryptFails", func(t *testing.T) { + t.Parallel() + uc, fake, keyRepo, tokenRepo, _ := newTokenizationUseCase(t) + dekID := allocateDekForTest(t, fake) + fake.FailDecrypt = apperrors.New("AEAD tag mismatch") - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: tokenizationKeyID, - Token: tokenValue, - Ciphertext: []byte("corrupted-ciphertext"), + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(&tokenizationDomain.Token{ + Token: "tok", + TokenizationKeyID: uuid.New(), + Ciphertext: []byte("ct"), Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - ExpiresAt: nil, - RevokedAt: nil, - } - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - } - - dekKey := make([]byte, 32) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - decryptionError := errors.New("decryption failed") - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Get(ctx, tokenizationKeyID). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() + }, nil) + keyRepo.EXPECT().Get(ctx, mock.Anything).Return(&tokenizationDomain.TokenizationKey{ + DekID: dekID, + }, nil) - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(tokenRecord.Ciphertext, tokenRecord.Nonce, mock.Anything). - Return(nil, decryptionError). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - plaintext, metadata, err := uc.Detokenize(ctx, tokenValue) - - // Assert - assert.Nil(t, plaintext) - assert.Nil(t, metadata) - assert.True(t, errors.Is(err, cryptoDomain.ErrDecryptionFailed)) - assert.Contains(t, err.Error(), "failed to decrypt token ciphertext") + _, _, err := uc.Detokenize(ctx, "tok") + assert.ErrorIs(t, err, cryptoDomain.ErrDecryptionFailed) }) } -// TestTokenizationUseCase_Validate tests the Validate method. func TestTokenizationUseCase_Validate(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_ValidToken", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - tokenValue := "valid-token" - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - ExpiresAt: nil, - RevokedAt: nil, - } - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - isValid, err := uc.Validate(ctx, tokenValue) - - // Assert - assert.NoError(t, err) - assert.True(t, isValid) + t.Run("Valid", func(t *testing.T) { + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(&tokenizationDomain.Token{ + Token: "tok", + CreatedAt: time.Now(), + }, nil) + + ok, err := uc.Validate(ctx, "tok") + require.NoError(t, err) + assert.True(t, ok) }) - t.Run("Success_ExpiredToken", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - expiredTime := time.Now().UTC().Add(-1 * time.Hour) - tokenValue := "expired-token" - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC().Add(-2 * time.Hour), - ExpiresAt: &expiredTime, - RevokedAt: nil, - } - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - isValid, err := uc.Validate(ctx, tokenValue) + t.Run("NotFound_ReturnsFalseNoError", func(t *testing.T) { + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(nil, tokenizationDomain.ErrTokenNotFound) - // Assert + ok, err := uc.Validate(ctx, "tok") assert.NoError(t, err) - assert.False(t, isValid) - }) - - t.Run("Success_TokenNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, "nonexistent-token"). - Return(nil, tokenizationDomain.ErrTokenNotFound). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - isValid, err := uc.Validate(ctx, "nonexistent-token") - - // Assert - assert.NoError(t, err) - assert.False(t, isValid) - }) - - t.Run("Error_RepositoryError", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - dbError := errors.New("database error") - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, "test-token"). - Return(nil, dbError). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - isValid, err := uc.Validate(ctx, "test-token") - - // Assert - assert.False(t, isValid) - assert.Error(t, err) - assert.True(t, errors.Is(err, dbError)) - assert.Contains(t, err.Error(), "failed to validate token") + assert.False(t, ok) }) } -// TestTokenizationUseCase_Revoke tests the Revoke method. func TestTokenizationUseCase_Revoke(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_RevokeToken", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - tokenValue := "token-to-revoke" - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - ExpiresAt: nil, - RevokedAt: nil, - } - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - mockTokenRepo.EXPECT(). - Revoke(ctx, tokenValue). - Return(nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - err := uc.Revoke(ctx, tokenValue) - - // Assert - assert.NoError(t, err) - }) - - t.Run("Error_TokenNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, "nonexistent-token"). - Return(nil, tokenizationDomain.ErrTokenNotFound). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - err := uc.Revoke(ctx, "nonexistent-token") - - // Assert - assert.Error(t, err) - assert.True(t, errors.Is(err, tokenizationDomain.ErrTokenNotFound)) - assert.Contains(t, err.Error(), "failed to get token for revocation") - }) - - t.Run("Error_RevokeFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - tokenValue := "test-token" - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: uuid.Must(uuid.NewV7()), - Token: tokenValue, - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - ExpiresAt: nil, - RevokedAt: nil, - } - - dbError := errors.New("database error") + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + tokenRepo.EXPECT().GetByToken(ctx, "tok").Return(&tokenizationDomain.Token{Token: "tok"}, nil) + tokenRepo.EXPECT().Revoke(ctx, "tok").Return(nil) - // Setup expectations - mockTokenRepo.EXPECT(). - GetByToken(ctx, tokenValue). - Return(tokenRecord, nil). - Once() - - mockTokenRepo.EXPECT(). - Revoke(ctx, tokenValue). - Return(dbError). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - err := uc.Revoke(ctx, tokenValue) - - // Assert - assert.Error(t, err) - assert.True(t, errors.Is(err, dbError)) - assert.Contains(t, err.Error(), "failed to revoke token") - }) + assert.NoError(t, uc.Revoke(ctx, "tok")) } -// TestTokenizationUseCase_CleanupExpired tests the CleanupExpired method. func TestTokenizationUseCase_CleanupExpired(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_DryRunMode", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenRepo.EXPECT(). - CountExpired(ctx, mock.MatchedBy(func(cutoff time.Time) bool { - // Verify cutoff is approximately 7 days ago - expectedCutoff := time.Now().UTC().AddDate(0, 0, -7) - // Allow 2 second variance for test execution time - return cutoff.After(expectedCutoff.Add(-2*time.Second)) && - cutoff.Before(expectedCutoff.Add(2*time.Second)) - })). - Return(int64(42), nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - count, err := uc.CleanupExpired(ctx, 7, true) - - // Assert - assert.NoError(t, err) - assert.Equal(t, int64(42), count) - }) - - t.Run("Success_DeleteMode", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Setup expectations - mockTokenRepo.EXPECT(). - DeleteExpired(ctx, mock.MatchedBy(func(cutoff time.Time) bool { - // Verify cutoff is approximately 30 days ago - expectedCutoff := time.Now().UTC().AddDate(0, 0, -30) - // Allow 2 second variance for test execution time - return cutoff.After(expectedCutoff.Add(-2*time.Second)) && - cutoff.Before(expectedCutoff.Add(2*time.Second)) - })). - Return(int64(100), nil). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - count, err := uc.CleanupExpired(ctx, 30, false) - - // Assert - assert.NoError(t, err) - assert.Equal(t, int64(100), count) - }) - t.Run("Error_NegativeDays", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - count, err := uc.CleanupExpired(ctx, -1, false) - - // Assert - assert.Equal(t, int64(0), count) + t.Parallel() + uc, _, _, _, _ := newTokenizationUseCase(t) + _, err := uc.CleanupExpired(ctx, -1, false) assert.Error(t, err) - assert.Contains(t, err.Error(), "days must be non-negative") - }) - - t.Run("Error_RepositoryError", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - dbError := errors.New("database error") - - // Setup expectations - mockTokenRepo.EXPECT(). - DeleteExpired(ctx, mock.AnythingOfType("time.Time")). - Return(int64(0), dbError). - Once() - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - count, err := uc.CleanupExpired(ctx, 7, false) - - // Assert - assert.Equal(t, int64(0), count) - assert.Equal(t, dbError, err) }) -} - -// TestTokenizationUseCase_TokenizeBatch tests the TokenizeBatch method. -func TestTokenizationUseCase_TokenizeBatch(t *testing.T) { - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - } - - dekKey := make([]byte, 32) - plaintexts := [][]byte{[]byte("value1"), []byte("value2")} - mockCipher := cryptoServiceMocks.NewMockAEAD(t) - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, f func(context.Context) error) { - _ = f(ctx) - }). - Return(nil). - Once() - - // Expectations for each item in batch - for range plaintexts { - mockTokenizationKeyRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(mock.Anything, mock.Anything). - Return([]byte("ciphertext"), []byte("nonce"), nil). - Once() - mockTokenRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(nil). - Once() - } - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - tokens, err := uc.TokenizeBatch(ctx, "test-key", plaintexts, nil, nil) + t.Run("Success_DryRun", func(t *testing.T) { + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + tokenRepo.EXPECT().CountExpired(ctx, mock.Anything).Return(int64(7), nil) - // Assert - assert.NoError(t, err) - assert.Len(t, tokens, 2) + n, err := uc.CleanupExpired(ctx, 30, true) + require.NoError(t, err) + assert.EqualValues(t, 7, n) }) -} - -// TestTokenizationUseCase_DetokenizeBatch tests the DetokenizeBatch method. -func TestTokenizationUseCase_DetokenizeBatch(t *testing.T) { - ctx := context.Background() - - t.Run("Success", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTokenizationKeyRepo := tokenizationMocks.NewMockTokenizationKeyRepository(t) - mockTokenRepo := tokenizationMocks.NewMockTokenRepository(t) - mockDekRepo := tokenizationMocks.NewMockDekRepository(t) - mockAEADManager := cryptoServiceMocks.NewMockAEADManager(t) - mockKeyManager := cryptoServiceMocks.NewMockKeyManager(t) - mockHashService := tokenizationMocks.NewMockHashService(t) - - // Create test data - masterKey := tokenizationTesting.CreateMasterKey() - kekChain := tokenizationTesting.CreateKekChain(masterKey) - defer kekChain.Close() - - activeKek := tokenizationTesting.GetActiveKek(kekChain) - dekID := uuid.Must(uuid.NewV7()) - tokenizationKeyID := uuid.Must(uuid.NewV7()) - tokens := []string{"token1", "token2"} - - tokenRecord := &tokenizationDomain.Token{ - ID: uuid.Must(uuid.NewV7()), - TokenizationKeyID: tokenizationKeyID, - Token: "token", - Ciphertext: []byte("ciphertext"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - } - - tokenizationKey := &tokenizationDomain.TokenizationKey{ - ID: tokenizationKeyID, - DekID: dekID, - Name: "test-key", - FormatType: tokenizationDomain.FormatUUID, - IsDeterministic: false, - Version: 1, - } - - dek := &cryptoDomain.Dek{ - ID: dekID, - KekID: activeKek.ID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("dek-nonce"), - } - dekKey := make([]byte, 32) - mockCipher := cryptoServiceMocks.NewMockAEAD(t) + t.Run("Success_Delete", func(t *testing.T) { + t.Parallel() + uc, _, _, tokenRepo, _ := newTokenizationUseCase(t) + tokenRepo.EXPECT().DeleteExpired(ctx, mock.Anything).Return(int64(4), nil) - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, f func(context.Context) error) { - _ = f(ctx) - }). - Return(nil). - Once() - - // Expectations for each item in batch - for range tokens { - mockTokenRepo.EXPECT(). - GetByToken(ctx, mock.Anything). - Return(tokenRecord, nil). - Once() - - mockTokenizationKeyRepo.EXPECT(). - Get(ctx, tokenizationKeyID). - Return(tokenizationKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dekID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, activeKek). - Return(dekKey, nil). - Once() - - mockAEADManager.EXPECT(). - CreateCipher(dekKey, cryptoDomain.AESGCM). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Decrypt(mock.Anything, mock.Anything, mock.Anything). - Return([]byte("plaintext"), nil). - Once() - } - - // Create use case - uc := NewTokenizationUseCase( - mockTxManager, - mockTokenizationKeyRepo, - mockTokenRepo, - mockDekRepo, - mockAEADManager, - mockKeyManager, - mockHashService, - kekChain, - ) - - // Execute - plaintexts, metadatas, err := uc.DetokenizeBatch(ctx, tokens) - - // Assert - assert.NoError(t, err) - assert.Len(t, plaintexts, 2) - assert.Len(t, metadatas, 2) + n, err := uc.CleanupExpired(ctx, 30, false) + require.NoError(t, err) + assert.EqualValues(t, 4, n) }) } diff --git a/internal/transit/domain/repository.go b/internal/transit/domain/repository.go index 9d19f97..e253b39 100644 --- a/internal/transit/domain/repository.go +++ b/internal/transit/domain/repository.go @@ -4,20 +4,9 @@ import ( "context" "time" - "github.com/google/uuid" - cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" ) -// DekRepository defines the interface for DEK persistence operations within the transit module. -type DekRepository interface { - // Create stores a new DEK in the repository using transaction support from context. - Create(ctx context.Context, dek *cryptoDomain.Dek) error - - // Get retrieves a DEK by its ID. Returns ErrDekNotFound if not found. - Get(ctx context.Context, dekID uuid.UUID) (*cryptoDomain.Dek, error) -} - // TransitKeyRepository defines the interface for transit key persistence. type TransitKeyRepository interface { // Create stores a new transit key in the repository using transaction support from context. diff --git a/internal/transit/usecase/interface.go b/internal/transit/usecase/interface.go index d4f8225..89ac46a 100644 --- a/internal/transit/usecase/interface.go +++ b/internal/transit/usecase/interface.go @@ -9,9 +9,8 @@ import ( transitDomain "github.com/allisson/secrets/internal/transit/domain" ) -// Re-export repository interfaces for convenience and backward compatibility if needed. -// However, the canonical location is now internal/transit/domain/repository.go. -type DekRepository = transitDomain.DekRepository +// TransitKeyRepository is re-exported for convenience. The canonical location +// is internal/transit/domain/repository.go. type TransitKeyRepository = transitDomain.TransitKeyRepository // TransitKeyUseCase defines the interface for transit encryption operations. diff --git a/internal/transit/usecase/mocks/mocks.go b/internal/transit/usecase/mocks/mocks.go index 35f49f9..b211199 100644 --- a/internal/transit/usecase/mocks/mocks.go +++ b/internal/transit/usecase/mocks/mocks.go @@ -8,164 +8,11 @@ import ( "context" "time" - "github.com/allisson/secrets/internal/crypto/domain" - domain0 "github.com/allisson/secrets/internal/transit/domain" - "github.com/google/uuid" + domain0 "github.com/allisson/secrets/internal/crypto/domain" + "github.com/allisson/secrets/internal/transit/domain" mock "github.com/stretchr/testify/mock" ) -// NewMockDekRepository creates a new instance of MockDekRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockDekRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *MockDekRepository { - mock := &MockDekRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockDekRepository is an autogenerated mock type for the DekRepository type -type MockDekRepository struct { - mock.Mock -} - -type MockDekRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockDekRepository) EXPECT() *MockDekRepository_Expecter { - return &MockDekRepository_Expecter{mock: &_m.Mock} -} - -// Create provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Create(ctx context.Context, dek *domain.Dek) error { - ret := _mock.Called(ctx, dek) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.Dek) error); ok { - r0 = returnFunc(ctx, dek) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockDekRepository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' -type MockDekRepository_Create_Call struct { - *mock.Call -} - -// Create is a helper method to define mock.On call -// - ctx context.Context -// - dek *domain.Dek -func (_e *MockDekRepository_Expecter) Create(ctx interface{}, dek interface{}) *MockDekRepository_Create_Call { - return &MockDekRepository_Create_Call{Call: _e.mock.On("Create", ctx, dek)} -} - -func (_c *MockDekRepository_Create_Call) Run(run func(ctx context.Context, dek *domain.Dek)) *MockDekRepository_Create_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 *domain.Dek - if args[1] != nil { - arg1 = args[1].(*domain.Dek) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Create_Call) Return(err error) *MockDekRepository_Create_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockDekRepository_Create_Call) RunAndReturn(run func(ctx context.Context, dek *domain.Dek) error) *MockDekRepository_Create_Call { - _c.Call.Return(run) - return _c -} - -// Get provides a mock function for the type MockDekRepository -func (_mock *MockDekRepository) Get(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error) { - ret := _mock.Called(ctx, dekID) - - if len(ret) == 0 { - panic("no return value specified for Get") - } - - var r0 *domain.Dek - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*domain.Dek, error)); ok { - return returnFunc(ctx, dekID) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, uuid.UUID) *domain.Dek); ok { - r0 = returnFunc(ctx, dekID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain.Dek) - } - } - if returnFunc, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { - r1 = returnFunc(ctx, dekID) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockDekRepository_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' -type MockDekRepository_Get_Call struct { - *mock.Call -} - -// Get is a helper method to define mock.On call -// - ctx context.Context -// - dekID uuid.UUID -func (_e *MockDekRepository_Expecter) Get(ctx interface{}, dekID interface{}) *MockDekRepository_Get_Call { - return &MockDekRepository_Get_Call{Call: _e.mock.On("Get", ctx, dekID)} -} - -func (_c *MockDekRepository_Get_Call) Run(run func(ctx context.Context, dekID uuid.UUID)) *MockDekRepository_Get_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 uuid.UUID - if args[1] != nil { - arg1 = args[1].(uuid.UUID) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockDekRepository_Get_Call) Return(dek *domain.Dek, err error) *MockDekRepository_Get_Call { - _c.Call.Return(dek, err) - return _c -} - -func (_c *MockDekRepository_Get_Call) RunAndReturn(run func(ctx context.Context, dekID uuid.UUID) (*domain.Dek, error)) *MockDekRepository_Get_Call { - _c.Call.Return(run) - return _c -} - // NewMockTransitKeyRepository creates a new instance of MockTransitKeyRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockTransitKeyRepository(t interface { @@ -194,7 +41,7 @@ func (_m *MockTransitKeyRepository) EXPECT() *MockTransitKeyRepository_Expecter } // Create provides a mock function for the type MockTransitKeyRepository -func (_mock *MockTransitKeyRepository) Create(ctx context.Context, transitKey *domain0.TransitKey) error { +func (_mock *MockTransitKeyRepository) Create(ctx context.Context, transitKey *domain.TransitKey) error { ret := _mock.Called(ctx, transitKey) if len(ret) == 0 { @@ -202,7 +49,7 @@ func (_mock *MockTransitKeyRepository) Create(ctx context.Context, transitKey *d } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *domain0.TransitKey) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *domain.TransitKey) error); ok { r0 = returnFunc(ctx, transitKey) } else { r0 = ret.Error(0) @@ -217,20 +64,20 @@ type MockTransitKeyRepository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - transitKey *domain0.TransitKey +// - transitKey *domain.TransitKey func (_e *MockTransitKeyRepository_Expecter) Create(ctx interface{}, transitKey interface{}) *MockTransitKeyRepository_Create_Call { return &MockTransitKeyRepository_Create_Call{Call: _e.mock.On("Create", ctx, transitKey)} } -func (_c *MockTransitKeyRepository_Create_Call) Run(run func(ctx context.Context, transitKey *domain0.TransitKey)) *MockTransitKeyRepository_Create_Call { +func (_c *MockTransitKeyRepository_Create_Call) Run(run func(ctx context.Context, transitKey *domain.TransitKey)) *MockTransitKeyRepository_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 *domain0.TransitKey + var arg1 *domain.TransitKey if args[1] != nil { - arg1 = args[1].(*domain0.TransitKey) + arg1 = args[1].(*domain.TransitKey) } run( arg0, @@ -245,7 +92,7 @@ func (_c *MockTransitKeyRepository_Create_Call) Return(err error) *MockTransitKe return _c } -func (_c *MockTransitKeyRepository_Create_Call) RunAndReturn(run func(ctx context.Context, transitKey *domain0.TransitKey) error) *MockTransitKeyRepository_Create_Call { +func (_c *MockTransitKeyRepository_Create_Call) RunAndReturn(run func(ctx context.Context, transitKey *domain.TransitKey) error) *MockTransitKeyRepository_Create_Call { _c.Call.Return(run) return _c } @@ -308,23 +155,23 @@ func (_c *MockTransitKeyRepository_Delete_Call) RunAndReturn(run func(ctx contex } // GetByName provides a mock function for the type MockTransitKeyRepository -func (_mock *MockTransitKeyRepository) GetByName(ctx context.Context, name string) (*domain0.TransitKey, error) { +func (_mock *MockTransitKeyRepository) GetByName(ctx context.Context, name string) (*domain.TransitKey, error) { ret := _mock.Called(ctx, name) if len(ret) == 0 { panic("no return value specified for GetByName") } - var r0 *domain0.TransitKey + var r0 *domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*domain.TransitKey, error)); ok { return returnFunc(ctx, name) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *domain.TransitKey); ok { r0 = returnFunc(ctx, name) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { @@ -365,34 +212,34 @@ func (_c *MockTransitKeyRepository_GetByName_Call) Run(run func(ctx context.Cont return _c } -func (_c *MockTransitKeyRepository_GetByName_Call) Return(transitKey *domain0.TransitKey, err error) *MockTransitKeyRepository_GetByName_Call { +func (_c *MockTransitKeyRepository_GetByName_Call) Return(transitKey *domain.TransitKey, err error) *MockTransitKeyRepository_GetByName_Call { _c.Call.Return(transitKey, err) return _c } -func (_c *MockTransitKeyRepository_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain0.TransitKey, error)) *MockTransitKeyRepository_GetByName_Call { +func (_c *MockTransitKeyRepository_GetByName_Call) RunAndReturn(run func(ctx context.Context, name string) (*domain.TransitKey, error)) *MockTransitKeyRepository_GetByName_Call { _c.Call.Return(run) return _c } // GetByNameAndVersion provides a mock function for the type MockTransitKeyRepository -func (_mock *MockTransitKeyRepository) GetByNameAndVersion(ctx context.Context, name string, version uint) (*domain0.TransitKey, error) { +func (_mock *MockTransitKeyRepository) GetByNameAndVersion(ctx context.Context, name string, version uint) (*domain.TransitKey, error) { ret := _mock.Called(ctx, name, version) if len(ret) == 0 { panic("no return value specified for GetByNameAndVersion") } - var r0 *domain0.TransitKey + var r0 *domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.TransitKey, error)); ok { return returnFunc(ctx, name, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.TransitKey); ok { r0 = returnFunc(ctx, name, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) error); ok { @@ -439,41 +286,41 @@ func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) Run(run func(ctx co return _c } -func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) Return(transitKey *domain0.TransitKey, err error) *MockTransitKeyRepository_GetByNameAndVersion_Call { +func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) Return(transitKey *domain.TransitKey, err error) *MockTransitKeyRepository_GetByNameAndVersion_Call { _c.Call.Return(transitKey, err) return _c } -func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, error)) *MockTransitKeyRepository_GetByNameAndVersion_Call { +func (_c *MockTransitKeyRepository_GetByNameAndVersion_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain.TransitKey, error)) *MockTransitKeyRepository_GetByNameAndVersion_Call { _c.Call.Return(run) return _c } // GetTransitKey provides a mock function for the type MockTransitKeyRepository -func (_mock *MockTransitKeyRepository) GetTransitKey(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error) { +func (_mock *MockTransitKeyRepository) GetTransitKey(ctx context.Context, name string, version uint) (*domain.TransitKey, domain0.Algorithm, error) { ret := _mock.Called(ctx, name, version) if len(ret) == 0 { panic("no return value specified for GetTransitKey") } - var r0 *domain0.TransitKey - var r1 domain.Algorithm + var r0 *domain.TransitKey + var r1 domain0.Algorithm var r2 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TransitKey, domain.Algorithm, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.TransitKey, domain0.Algorithm, error)); ok { return returnFunc(ctx, name, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.TransitKey); ok { r0 = returnFunc(ctx, name, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain.Algorithm); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain0.Algorithm); ok { r1 = returnFunc(ctx, name, version) } else { - r1 = ret.Get(1).(domain.Algorithm) + r1 = ret.Get(1).(domain0.Algorithm) } if returnFunc, ok := ret.Get(2).(func(context.Context, string, uint) error); ok { r2 = returnFunc(ctx, name, version) @@ -519,12 +366,12 @@ func (_c *MockTransitKeyRepository_GetTransitKey_Call) Run(run func(ctx context. return _c } -func (_c *MockTransitKeyRepository_GetTransitKey_Call) Return(transitKey *domain0.TransitKey, algorithm domain.Algorithm, err error) *MockTransitKeyRepository_GetTransitKey_Call { +func (_c *MockTransitKeyRepository_GetTransitKey_Call) Return(transitKey *domain.TransitKey, algorithm domain0.Algorithm, err error) *MockTransitKeyRepository_GetTransitKey_Call { _c.Call.Return(transitKey, algorithm, err) return _c } -func (_c *MockTransitKeyRepository_GetTransitKey_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error)) *MockTransitKeyRepository_GetTransitKey_Call { +func (_c *MockTransitKeyRepository_GetTransitKey_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain.TransitKey, domain0.Algorithm, error)) *MockTransitKeyRepository_GetTransitKey_Call { _c.Call.Return(run) return _c } @@ -602,23 +449,23 @@ func (_c *MockTransitKeyRepository_HardDelete_Call) RunAndReturn(run func(ctx co } // ListCursor provides a mock function for the type MockTransitKeyRepository -func (_mock *MockTransitKeyRepository) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain0.TransitKey, error) { +func (_mock *MockTransitKeyRepository) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain.TransitKey, error) { ret := _mock.Called(ctx, afterName, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.TransitKey + var r0 []*domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.TransitKey, error)); ok { return returnFunc(ctx, afterName, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.TransitKey); ok { r0 = returnFunc(ctx, afterName, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.TransitKey) + r0 = ret.Get(0).([]*domain.TransitKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -665,12 +512,12 @@ func (_c *MockTransitKeyRepository_ListCursor_Call) Run(run func(ctx context.Con return _c } -func (_c *MockTransitKeyRepository_ListCursor_Call) Return(transitKeys []*domain0.TransitKey, err error) *MockTransitKeyRepository_ListCursor_Call { +func (_c *MockTransitKeyRepository_ListCursor_Call) Return(transitKeys []*domain.TransitKey, err error) *MockTransitKeyRepository_ListCursor_Call { _c.Call.Return(transitKeys, err) return _c } -func (_c *MockTransitKeyRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain0.TransitKey, error)) *MockTransitKeyRepository_ListCursor_Call { +func (_c *MockTransitKeyRepository_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain.TransitKey, error)) *MockTransitKeyRepository_ListCursor_Call { _c.Call.Return(run) return _c } @@ -703,26 +550,26 @@ func (_m *MockTransitKeyUseCase) EXPECT() *MockTransitKeyUseCase_Expecter { } // Create provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) Create(ctx context.Context, name string, alg domain.Algorithm) (*domain0.TransitKey, error) { +func (_mock *MockTransitKeyUseCase) Create(ctx context.Context, name string, alg domain0.Algorithm) (*domain.TransitKey, error) { ret := _mock.Called(ctx, name, alg) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 *domain0.TransitKey + var r0 *domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.Algorithm) (*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.Algorithm) (*domain.TransitKey, error)); ok { return returnFunc(ctx, name, alg) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.Algorithm) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.Algorithm) *domain.TransitKey); ok { r0 = returnFunc(ctx, name, alg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain.Algorithm) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain0.Algorithm) error); ok { r1 = returnFunc(ctx, name, alg) } else { r1 = ret.Error(1) @@ -738,12 +585,12 @@ type MockTransitKeyUseCase_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context // - name string -// - alg domain.Algorithm +// - alg domain0.Algorithm func (_e *MockTransitKeyUseCase_Expecter) Create(ctx interface{}, name interface{}, alg interface{}) *MockTransitKeyUseCase_Create_Call { return &MockTransitKeyUseCase_Create_Call{Call: _e.mock.On("Create", ctx, name, alg)} } -func (_c *MockTransitKeyUseCase_Create_Call) Run(run func(ctx context.Context, name string, alg domain.Algorithm)) *MockTransitKeyUseCase_Create_Call { +func (_c *MockTransitKeyUseCase_Create_Call) Run(run func(ctx context.Context, name string, alg domain0.Algorithm)) *MockTransitKeyUseCase_Create_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -753,9 +600,9 @@ func (_c *MockTransitKeyUseCase_Create_Call) Run(run func(ctx context.Context, n if args[1] != nil { arg1 = args[1].(string) } - var arg2 domain.Algorithm + var arg2 domain0.Algorithm if args[2] != nil { - arg2 = args[2].(domain.Algorithm) + arg2 = args[2].(domain0.Algorithm) } run( arg0, @@ -766,34 +613,34 @@ func (_c *MockTransitKeyUseCase_Create_Call) Run(run func(ctx context.Context, n return _c } -func (_c *MockTransitKeyUseCase_Create_Call) Return(transitKey *domain0.TransitKey, err error) *MockTransitKeyUseCase_Create_Call { +func (_c *MockTransitKeyUseCase_Create_Call) Return(transitKey *domain.TransitKey, err error) *MockTransitKeyUseCase_Create_Call { _c.Call.Return(transitKey, err) return _c } -func (_c *MockTransitKeyUseCase_Create_Call) RunAndReturn(run func(ctx context.Context, name string, alg domain.Algorithm) (*domain0.TransitKey, error)) *MockTransitKeyUseCase_Create_Call { +func (_c *MockTransitKeyUseCase_Create_Call) RunAndReturn(run func(ctx context.Context, name string, alg domain0.Algorithm) (*domain.TransitKey, error)) *MockTransitKeyUseCase_Create_Call { _c.Call.Return(run) return _c } // Decrypt provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) Decrypt(ctx context.Context, name string, ciphertext string, context1 []byte) (*domain0.EncryptedBlob, error) { +func (_mock *MockTransitKeyUseCase) Decrypt(ctx context.Context, name string, ciphertext string, context1 []byte) (*domain.EncryptedBlob, error) { ret := _mock.Called(ctx, name, ciphertext, context1) if len(ret) == 0 { panic("no return value specified for Decrypt") } - var r0 *domain0.EncryptedBlob + var r0 *domain.EncryptedBlob var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []byte) (*domain0.EncryptedBlob, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []byte) (*domain.EncryptedBlob, error)); ok { return returnFunc(ctx, name, ciphertext, context1) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []byte) *domain0.EncryptedBlob); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []byte) *domain.EncryptedBlob); ok { r0 = returnFunc(ctx, name, ciphertext, context1) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.EncryptedBlob) + r0 = ret.Get(0).(*domain.EncryptedBlob) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, []byte) error); ok { @@ -846,12 +693,12 @@ func (_c *MockTransitKeyUseCase_Decrypt_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockTransitKeyUseCase_Decrypt_Call) Return(encryptedBlob *domain0.EncryptedBlob, err error) *MockTransitKeyUseCase_Decrypt_Call { +func (_c *MockTransitKeyUseCase_Decrypt_Call) Return(encryptedBlob *domain.EncryptedBlob, err error) *MockTransitKeyUseCase_Decrypt_Call { _c.Call.Return(encryptedBlob, err) return _c } -func (_c *MockTransitKeyUseCase_Decrypt_Call) RunAndReturn(run func(ctx context.Context, name string, ciphertext string, context1 []byte) (*domain0.EncryptedBlob, error)) *MockTransitKeyUseCase_Decrypt_Call { +func (_c *MockTransitKeyUseCase_Decrypt_Call) RunAndReturn(run func(ctx context.Context, name string, ciphertext string, context1 []byte) (*domain.EncryptedBlob, error)) *MockTransitKeyUseCase_Decrypt_Call { _c.Call.Return(run) return _c } @@ -914,23 +761,23 @@ func (_c *MockTransitKeyUseCase_Delete_Call) RunAndReturn(run func(ctx context.C } // Encrypt provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) Encrypt(ctx context.Context, name string, plaintext []byte, context1 []byte) (*domain0.EncryptedBlob, error) { +func (_mock *MockTransitKeyUseCase) Encrypt(ctx context.Context, name string, plaintext []byte, context1 []byte) (*domain.EncryptedBlob, error) { ret := _mock.Called(ctx, name, plaintext, context1) if len(ret) == 0 { panic("no return value specified for Encrypt") } - var r0 *domain0.EncryptedBlob + var r0 *domain.EncryptedBlob var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, []byte) (*domain0.EncryptedBlob, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, []byte) (*domain.EncryptedBlob, error)); ok { return returnFunc(ctx, name, plaintext, context1) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, []byte) *domain0.EncryptedBlob); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte, []byte) *domain.EncryptedBlob); ok { r0 = returnFunc(ctx, name, plaintext, context1) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.EncryptedBlob) + r0 = ret.Get(0).(*domain.EncryptedBlob) } } if returnFunc, ok := ret.Get(1).(func(context.Context, string, []byte, []byte) error); ok { @@ -983,41 +830,41 @@ func (_c *MockTransitKeyUseCase_Encrypt_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockTransitKeyUseCase_Encrypt_Call) Return(encryptedBlob *domain0.EncryptedBlob, err error) *MockTransitKeyUseCase_Encrypt_Call { +func (_c *MockTransitKeyUseCase_Encrypt_Call) Return(encryptedBlob *domain.EncryptedBlob, err error) *MockTransitKeyUseCase_Encrypt_Call { _c.Call.Return(encryptedBlob, err) return _c } -func (_c *MockTransitKeyUseCase_Encrypt_Call) RunAndReturn(run func(ctx context.Context, name string, plaintext []byte, context1 []byte) (*domain0.EncryptedBlob, error)) *MockTransitKeyUseCase_Encrypt_Call { +func (_c *MockTransitKeyUseCase_Encrypt_Call) RunAndReturn(run func(ctx context.Context, name string, plaintext []byte, context1 []byte) (*domain.EncryptedBlob, error)) *MockTransitKeyUseCase_Encrypt_Call { _c.Call.Return(run) return _c } // Get provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) Get(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error) { +func (_mock *MockTransitKeyUseCase) Get(ctx context.Context, name string, version uint) (*domain.TransitKey, domain0.Algorithm, error) { ret := _mock.Called(ctx, name, version) if len(ret) == 0 { panic("no return value specified for Get") } - var r0 *domain0.TransitKey - var r1 domain.Algorithm + var r0 *domain.TransitKey + var r1 domain0.Algorithm var r2 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain0.TransitKey, domain.Algorithm, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) (*domain.TransitKey, domain0.Algorithm, error)); ok { return returnFunc(ctx, name, version) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint) *domain.TransitKey); ok { r0 = returnFunc(ctx, name, version) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain.Algorithm); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint) domain0.Algorithm); ok { r1 = returnFunc(ctx, name, version) } else { - r1 = ret.Get(1).(domain.Algorithm) + r1 = ret.Get(1).(domain0.Algorithm) } if returnFunc, ok := ret.Get(2).(func(context.Context, string, uint) error); ok { r2 = returnFunc(ctx, name, version) @@ -1063,34 +910,34 @@ func (_c *MockTransitKeyUseCase_Get_Call) Run(run func(ctx context.Context, name return _c } -func (_c *MockTransitKeyUseCase_Get_Call) Return(transitKey *domain0.TransitKey, algorithm domain.Algorithm, err error) *MockTransitKeyUseCase_Get_Call { +func (_c *MockTransitKeyUseCase_Get_Call) Return(transitKey *domain.TransitKey, algorithm domain0.Algorithm, err error) *MockTransitKeyUseCase_Get_Call { _c.Call.Return(transitKey, algorithm, err) return _c } -func (_c *MockTransitKeyUseCase_Get_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain0.TransitKey, domain.Algorithm, error)) *MockTransitKeyUseCase_Get_Call { +func (_c *MockTransitKeyUseCase_Get_Call) RunAndReturn(run func(ctx context.Context, name string, version uint) (*domain.TransitKey, domain0.Algorithm, error)) *MockTransitKeyUseCase_Get_Call { _c.Call.Return(run) return _c } // ListCursor provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain0.TransitKey, error) { +func (_mock *MockTransitKeyUseCase) ListCursor(ctx context.Context, afterName *string, limit int) ([]*domain.TransitKey, error) { ret := _mock.Called(ctx, afterName, limit) if len(ret) == 0 { panic("no return value specified for ListCursor") } - var r0 []*domain0.TransitKey + var r0 []*domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) ([]*domain.TransitKey, error)); ok { return returnFunc(ctx, afterName, limit) } - if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, *string, int) []*domain.TransitKey); ok { r0 = returnFunc(ctx, afterName, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*domain0.TransitKey) + r0 = ret.Get(0).([]*domain.TransitKey) } } if returnFunc, ok := ret.Get(1).(func(context.Context, *string, int) error); ok { @@ -1137,12 +984,12 @@ func (_c *MockTransitKeyUseCase_ListCursor_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockTransitKeyUseCase_ListCursor_Call) Return(transitKeys []*domain0.TransitKey, err error) *MockTransitKeyUseCase_ListCursor_Call { +func (_c *MockTransitKeyUseCase_ListCursor_Call) Return(transitKeys []*domain.TransitKey, err error) *MockTransitKeyUseCase_ListCursor_Call { _c.Call.Return(transitKeys, err) return _c } -func (_c *MockTransitKeyUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain0.TransitKey, error)) *MockTransitKeyUseCase_ListCursor_Call { +func (_c *MockTransitKeyUseCase_ListCursor_Call) RunAndReturn(run func(ctx context.Context, afterName *string, limit int) ([]*domain.TransitKey, error)) *MockTransitKeyUseCase_ListCursor_Call { _c.Call.Return(run) return _c } @@ -1220,26 +1067,26 @@ func (_c *MockTransitKeyUseCase_PurgeDeleted_Call) RunAndReturn(run func(ctx con } // Rotate provides a mock function for the type MockTransitKeyUseCase -func (_mock *MockTransitKeyUseCase) Rotate(ctx context.Context, name string, alg domain.Algorithm) (*domain0.TransitKey, error) { +func (_mock *MockTransitKeyUseCase) Rotate(ctx context.Context, name string, alg domain0.Algorithm) (*domain.TransitKey, error) { ret := _mock.Called(ctx, name, alg) if len(ret) == 0 { panic("no return value specified for Rotate") } - var r0 *domain0.TransitKey + var r0 *domain.TransitKey var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.Algorithm) (*domain0.TransitKey, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.Algorithm) (*domain.TransitKey, error)); ok { return returnFunc(ctx, name, alg) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain.Algorithm) *domain0.TransitKey); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, domain0.Algorithm) *domain.TransitKey); ok { r0 = returnFunc(ctx, name, alg) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*domain0.TransitKey) + r0 = ret.Get(0).(*domain.TransitKey) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain.Algorithm) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, domain0.Algorithm) error); ok { r1 = returnFunc(ctx, name, alg) } else { r1 = ret.Error(1) @@ -1255,12 +1102,12 @@ type MockTransitKeyUseCase_Rotate_Call struct { // Rotate is a helper method to define mock.On call // - ctx context.Context // - name string -// - alg domain.Algorithm +// - alg domain0.Algorithm func (_e *MockTransitKeyUseCase_Expecter) Rotate(ctx interface{}, name interface{}, alg interface{}) *MockTransitKeyUseCase_Rotate_Call { return &MockTransitKeyUseCase_Rotate_Call{Call: _e.mock.On("Rotate", ctx, name, alg)} } -func (_c *MockTransitKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, name string, alg domain.Algorithm)) *MockTransitKeyUseCase_Rotate_Call { +func (_c *MockTransitKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, name string, alg domain0.Algorithm)) *MockTransitKeyUseCase_Rotate_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -1270,9 +1117,9 @@ func (_c *MockTransitKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, n if args[1] != nil { arg1 = args[1].(string) } - var arg2 domain.Algorithm + var arg2 domain0.Algorithm if args[2] != nil { - arg2 = args[2].(domain.Algorithm) + arg2 = args[2].(domain0.Algorithm) } run( arg0, @@ -1283,12 +1130,12 @@ func (_c *MockTransitKeyUseCase_Rotate_Call) Run(run func(ctx context.Context, n return _c } -func (_c *MockTransitKeyUseCase_Rotate_Call) Return(transitKey *domain0.TransitKey, err error) *MockTransitKeyUseCase_Rotate_Call { +func (_c *MockTransitKeyUseCase_Rotate_Call) Return(transitKey *domain.TransitKey, err error) *MockTransitKeyUseCase_Rotate_Call { _c.Call.Return(transitKey, err) return _c } -func (_c *MockTransitKeyUseCase_Rotate_Call) RunAndReturn(run func(ctx context.Context, name string, alg domain.Algorithm) (*domain0.TransitKey, error)) *MockTransitKeyUseCase_Rotate_Call { +func (_c *MockTransitKeyUseCase_Rotate_Call) RunAndReturn(run func(ctx context.Context, name string, alg domain0.Algorithm) (*domain.TransitKey, error)) *MockTransitKeyUseCase_Rotate_Call { _c.Call.Return(run) return _c } diff --git a/internal/transit/usecase/transit_key_usecase.go b/internal/transit/usecase/transit_key_usecase.go index 87f6daf..336d99e 100644 --- a/internal/transit/usecase/transit_key_usecase.go +++ b/internal/transit/usecase/transit_key_usecase.go @@ -1,7 +1,7 @@ // Package usecase implements transit encryption business logic. // -// Coordinates between cryptographic services and repositories to manage transit keys -// with versioning and envelope encryption. Uses TxManager for transactional consistency. +// Coordinates between the keyring and the transit repository to manage transit keys +// with versioning and envelope encryption. package usecase import ( @@ -11,33 +11,26 @@ import ( "github.com/google/uuid" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - cryptoService "github.com/allisson/secrets/internal/crypto/service" "github.com/allisson/secrets/internal/database" apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" transitDomain "github.com/allisson/secrets/internal/transit/domain" ) +// nonceSize is the AEAD nonce length stored alongside ciphertext in the +// transit wire format. Both supported algorithms (AES-256-GCM, +// ChaCha20-Poly1305) use 12-byte nonces. If we add an algorithm with a +// different nonce size, this needs to be derived from the algorithm. +const nonceSize = 12 + // transitKeyUseCase implements TransitKeyUseCase for managing transit keys. type transitKeyUseCase struct { txManager database.TxManager transitRepo TransitKeyRepository - dekRepo DekRepository - keyManager cryptoService.KeyManager - aeadManager cryptoService.AEADManager - kekChain *cryptoDomain.KekChain -} - -// getKek retrieves a KEK from the chain by its ID. -func (t *transitKeyUseCase) getKek(kekID uuid.UUID) (*cryptoDomain.Kek, error) { - kek, ok := t.kekChain.Get(kekID) - if !ok { - return nil, cryptoDomain.ErrKekNotFound - } - return kek, nil + keyring keyring.Keyring } // Create generates and persists a new transit key with version 1. -// Returns ErrTransitKeyAlreadyExists if a transit key with the same name already exists. func (t *transitKeyUseCase) Create( ctx context.Context, name string, @@ -46,47 +39,28 @@ func (t *transitKeyUseCase) Create( var transitKey *transitDomain.TransitKey err := t.txManager.WithTx(ctx, func(txCtx context.Context) error { - // Check if transit key with version 1 already exists existingKey, err := t.transitRepo.GetByNameAndVersion(txCtx, name, 1) if err != nil && !apperrors.Is(err, transitDomain.ErrTransitKeyNotFound) { - // Return unexpected database errors return err } if existingKey != nil { - // Transit key already exists with version 1 return transitDomain.ErrTransitKeyAlreadyExists } - // Get active KEK from chain - activeKek, err := t.getKek(t.kekChain.ActiveKekID()) - if err != nil { - return err - } - - // Create DEK encrypted with active KEK - dek, err := t.keyManager.CreateDek(activeKek, alg) + handle, err := t.keyring.AllocateDek(txCtx, alg) if err != nil { return err } - // Persist DEK to database - if err := t.dekRepo.Create(txCtx, &dek); err != nil { - return err - } - - // Create transit key with version 1 transitKey = &transitDomain.TransitKey{ ID: uuid.Must(uuid.NewV7()), Name: name, Version: 1, - DekID: dek.ID, + DekID: handle.DekID, CreatedAt: time.Now().UTC(), } - - // Persist transit key return t.transitRepo.Create(txCtx, transitKey) }) - if err != nil { return nil, err } @@ -103,10 +77,8 @@ func (t *transitKeyUseCase) Rotate( var newTransitKey *transitDomain.TransitKey err := t.txManager.WithTx(ctx, func(txCtx context.Context) error { - // Get latest transit key version currentKey, err := t.transitRepo.GetByName(txCtx, name) if err != nil { - // If key doesn't exist, create first version if apperrors.Is(err, transitDomain.ErrTransitKeyNotFound) { newTransitKey, err = t.Create(txCtx, name, alg) return err @@ -114,36 +86,20 @@ func (t *transitKeyUseCase) Rotate( return err } - // Get active KEK from chain - activeKek, err := t.getKek(t.kekChain.ActiveKekID()) - if err != nil { - return err - } - - // Create new DEK encrypted with active KEK - dek, err := t.keyManager.CreateDek(activeKek, alg) + handle, err := t.keyring.AllocateDek(txCtx, alg) if err != nil { return err } - // Persist new DEK - if err := t.dekRepo.Create(txCtx, &dek); err != nil { - return err - } - - // Create new transit key with incremented version newTransitKey = &transitDomain.TransitKey{ ID: uuid.Must(uuid.NewV7()), Name: name, Version: currentKey.Version + 1, - DekID: dek.ID, + DekID: handle.DekID, CreatedAt: time.Now().UTC(), } - - // Persist new transit key return t.transitRepo.Create(txCtx, newTransitKey) }) - if err != nil { return nil, err } @@ -166,50 +122,25 @@ func (t *transitKeyUseCase) Delete(ctx context.Context, name string) error { } // Encrypt encrypts plaintext using the latest version of a named transit key. +// +// The returned EncryptedBlob.Ciphertext is `nonce || ciphertext`, base64-encoded +// in the wire format `version:base64(...)`. See ADR-0002. func (t *transitKeyUseCase) Encrypt( ctx context.Context, name string, plaintext, context []byte, ) (*transitDomain.EncryptedBlob, error) { - // Get latest transit key version transitKey, err := t.transitRepo.GetByName(ctx, name) if err != nil { return nil, err } - // Get DEK by transit key's DekID - dek, err := t.dekRepo.Get(ctx, transitKey.DekID) - if err != nil { - return nil, err - } - - // Get KEK for decrypting DEK - kek, err := t.getKek(dek.KekID) - if err != nil { - return nil, err - } - - // Decrypt DEK with KEK - dekKey, err := t.keyManager.DecryptDek(dek, kek) - if err != nil { - return nil, err - } - defer cryptoDomain.Zero(dekKey) - - // Create AEAD cipher with decrypted DEK - cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) - if err != nil { - return nil, err - } - - // Encrypt plaintext with optional context - ciphertext, nonce, err := cipher.Encrypt(plaintext, context) + handle := keyring.DekHandle{DekID: transitKey.DekID} + ciphertext, nonce, err := t.keyring.EncryptWith(ctx, handle, plaintext, context) if err != nil { return nil, apperrors.Wrap(err, "failed to encrypt plaintext") } - // Combine ciphertext and nonce (nonce is prepended to ciphertext by AEAD) - // The AEAD Encrypt returns ciphertext with authentication tag, we need to store nonce separately encryptedData := make([]byte, 0, len(nonce)+len(ciphertext)) encryptedData = append(encryptedData, nonce...) encryptedData = append(encryptedData, ciphertext...) @@ -228,55 +159,24 @@ func (t *transitKeyUseCase) Decrypt( ciphertext string, context []byte, ) (*transitDomain.EncryptedBlob, error) { - // Parse encrypted blob from ciphertext string format "version:base64..." blob, err := transitDomain.NewEncryptedBlob(ciphertext) if err != nil { return nil, err } - // Get transit key by name and version from blob transitKey, err := t.transitRepo.GetByNameAndVersion(ctx, name, blob.Version) if err != nil { return nil, err } - // Get DEK by transit key's DekID - dek, err := t.dekRepo.Get(ctx, transitKey.DekID) - if err != nil { - return nil, err - } - - // Get KEK for decrypting DEK - kek, err := t.getKek(dek.KekID) - if err != nil { - return nil, err - } - - // Decrypt DEK with KEK - dekKey, err := t.keyManager.DecryptDek(dek, kek) - if err != nil { - return nil, err - } - defer cryptoDomain.Zero(dekKey) - - // Create AEAD cipher with decrypted DEK - cipher, err := t.aeadManager.CreateCipher(dekKey, dek.Algorithm) - if err != nil { - return nil, err - } - - // Extract nonce and ciphertext from encrypted data - // The nonce is prepended to the ciphertext - nonceSize := cipher.NonceSize() if len(blob.Ciphertext) < nonceSize { return nil, apperrors.Wrap(cryptoDomain.ErrDecryptionFailed, "ciphertext too short") } - nonce := blob.Ciphertext[:nonceSize] encryptedData := blob.Ciphertext[nonceSize:] - // Decrypt ciphertext with optional context - plaintext, err := cipher.Decrypt(encryptedData, nonce, context) + handle := keyring.DekHandle{DekID: transitKey.DekID} + plaintext, err := t.keyring.DecryptWith(ctx, handle, encryptedData, nonce, context) if err != nil { return nil, cryptoDomain.ErrDecryptionFailed } @@ -307,21 +207,15 @@ func (t *transitKeyUseCase) PurgeDeleted(ctx context.Context, olderThanDays int, return t.transitRepo.HardDelete(ctx, olderThan, dryRun) } -// NewTransitKeyUseCase creates a new TransitKeyUseCase with injected dependencies. +// NewTransitKeyUseCase creates a new TransitKeyUseCase backed by a Keyring. func NewTransitKeyUseCase( txManager database.TxManager, transitRepo TransitKeyRepository, - dekRepo DekRepository, - keyManager cryptoService.KeyManager, - aeadManager cryptoService.AEADManager, - kekChain *cryptoDomain.KekChain, + kr keyring.Keyring, ) TransitKeyUseCase { return &transitKeyUseCase{ txManager: txManager, transitRepo: transitRepo, - dekRepo: dekRepo, - keyManager: keyManager, - aeadManager: aeadManager, - kekChain: kekChain, + keyring: kr, } } diff --git a/internal/transit/usecase/transit_key_usecase_test.go b/internal/transit/usecase/transit_key_usecase_test.go index a57c271..5a51baf 100644 --- a/internal/transit/usecase/transit_key_usecase_test.go +++ b/internal/transit/usecase/transit_key_usecase_test.go @@ -1,1536 +1,220 @@ -package usecase +package usecase_test import ( "context" - "errors" + "encoding/base64" + "fmt" "testing" - "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" cryptoDomain "github.com/allisson/secrets/internal/crypto/domain" - serviceMocks "github.com/allisson/secrets/internal/crypto/service/mocks" - databaseMocks "github.com/allisson/secrets/internal/database/mocks" - apperrors "github.com/allisson/secrets/internal/errors" + "github.com/allisson/secrets/internal/keyring" transitDomain "github.com/allisson/secrets/internal/transit/domain" - usecaseMocks "github.com/allisson/secrets/internal/transit/usecase/mocks" + "github.com/allisson/secrets/internal/transit/usecase" + "github.com/allisson/secrets/internal/transit/usecase/mocks" ) -// Helper function to create a test KEK chain -func createTestKekChain(activeKekID uuid.UUID, kek *cryptoDomain.Kek) *cryptoDomain.KekChain { - keks := []*cryptoDomain.Kek{kek} - return cryptoDomain.NewKekChain(keks) -} +// noopTxManager runs the function with no real transaction. +type noopTxManager struct{} -// Helper function to create a test KEK -func createTestKek() *cryptoDomain.Kek { - return &cryptoDomain.Kek{ - ID: uuid.Must(uuid.NewV7()), - MasterKeyID: "test-master-key", - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-kek"), - Key: make([]byte, 32), - Nonce: []byte("nonce"), - Version: 1, - CreatedAt: time.Now().UTC(), - } +func (noopTxManager) WithTx(ctx context.Context, fn func(ctx context.Context) error) error { + return fn(ctx) } -// Helper function to create a test DEK -func createTestDek(kekID uuid.UUID) *cryptoDomain.Dek { - return &cryptoDomain.Dek{ - ID: uuid.Must(uuid.NewV7()), - KekID: kekID, - Algorithm: cryptoDomain.AESGCM, - EncryptedKey: []byte("encrypted-dek"), - Nonce: []byte("nonce"), - CreatedAt: time.Now().UTC(), - } +func newTransitKeyUseCase( + t *testing.T, +) (usecase.TransitKeyUseCase, *keyring.Fake, *mocks.MockTransitKeyRepository) { + t.Helper() + fake := keyring.NewFake() + repo := mocks.NewMockTransitKeyRepository(t) + uc := usecase.NewTransitKeyUseCase(noopTxManager{}, repo, fake) + return uc, fake, repo } -// Helper function to create a test transit key -func createTestTransitKey(name string, version uint, dekID uuid.UUID) *transitDomain.TransitKey { - return &transitDomain.TransitKey{ - ID: uuid.Must(uuid.NewV7()), - Name: name, - Version: version, - DekID: dekID, - CreatedAt: time.Now().UTC(), - } +// allocateDekForTest seeds the keyring Fake with a DekID and returns it. +func allocateDekForTest(t *testing.T, fake *keyring.Fake) uuid.UUID { + t.Helper() + handle, err := fake.AllocateDek(context.Background(), keyring.AESGCM) + require.NoError(t, err) + return handle.DekID } -// TestTransitKeyUseCase_Create tests the Create method of transitKeyUseCase. func TestTransitKeyUseCase_Create(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_CreateTransitKeyWithAESGCM", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedDek := createTestDek(kek.ID) - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(*expectedDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(dek *cryptoDomain.Dek) bool { - return dek.ID == expectedDek.ID && dek.KekID == expectedDek.KekID - })). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(ctx, mock.MatchedBy(func(tk *transitDomain.TransitKey) bool { - return tk.Name == "test-key" && tk.Version == 1 && tk.DekID == expectedDek.ID + t.Run("Success", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTransitKeyUseCase(t) + repo.EXPECT(). + GetByNameAndVersion(ctx, "k", uint(1)). + Return(nil, transitDomain.ErrTransitKeyNotFound) + repo.EXPECT(). + Create(ctx, mock.MatchedBy(func(k *transitDomain.TransitKey) bool { + return k.Name == "k" && k.Version == 1 })). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, transitKey) - assert.Equal(t, "test-key", transitKey.Name) - assert.Equal(t, uint(1), transitKey.Version) - assert.Equal(t, expectedDek.ID, transitKey.DekID) - }) - - t.Run("Success_CreateTransitKeyWithChaCha20", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedDek := createTestDek(kek.ID) - expectedDek.Algorithm = cryptoDomain.ChaCha20 - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.ChaCha20). - Return(*expectedDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.ChaCha20) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, transitKey) - assert.Equal(t, uint(1), transitKey.Version) - }) - - t.Run("Error_TransitKeyAlreadyExists", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - existingTransitKey := createTestTransitKey("test-key", 1, uuid.Must(uuid.NewV7())) - - // Setup expectations - GetByNameAndVersion should return existing key - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(existingTransitKey, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) - assert.True(t, apperrors.Is(err, transitDomain.ErrTransitKeyAlreadyExists)) - assert.True(t, apperrors.Is(err, apperrors.ErrConflict)) - }) + Return(nil) - t.Run("Error_DekCreationFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedError := errors.New("dek creation failed") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(cryptoDomain.Dek{}, expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) - assert.Equal(t, expectedError, err) + key, err := uc.Create(ctx, "k", cryptoDomain.AESGCM) + require.NoError(t, err) + assert.Equal(t, "k", key.Name) + assert.EqualValues(t, 1, key.Version) + assert.NotEqual(t, uuid.Nil, key.DekID) }) - t.Run("Error_DekPersistenceFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() + t.Run("Error_AlreadyExists", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTransitKeyUseCase(t) + repo.EXPECT(). + GetByNameAndVersion(ctx, "dup", uint(1)). + Return(&transitDomain.TransitKey{}, nil) - expectedDek := createTestDek(kek.ID) - expectedError := errors.New("database error") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(*expectedDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) - assert.Equal(t, expectedError, err) - }) - - t.Run("Error_TransitKeyPersistenceFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedDek := createTestDek(kek.ID) - expectedError := errors.New("database error") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - RunAndReturn(func(ctx context.Context, fn func(context.Context) error) error { - return fn(ctx) - }). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(*expectedDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(ctx, mock.Anything). - Return(expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Create(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) - assert.Equal(t, expectedError, err) + _, err := uc.Create(ctx, "dup", cryptoDomain.AESGCM) + assert.ErrorIs(t, err, transitDomain.ErrTransitKeyAlreadyExists) }) } -// TestTransitKeyUseCase_Rotate tests the Rotate method of transitKeyUseCase. func TestTransitKeyUseCase_Rotate(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_RotateToNewVersion", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - existingDek := createTestDek(kek.ID) - currentKey := createTestTransitKey("test-key", 1, existingDek.ID) - newDek := createTestDek(kek.ID) - - // Setup expectations for transaction - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(currentKey, nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(*newDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(dek *cryptoDomain.Dek) bool { - return dek.ID == newDek.ID - })). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(tk *transitDomain.TransitKey) bool { - return tk.Name == "test-key" && tk.Version == 2 && tk.DekID == newDek.ID + t.Run("Success_IncrementsVersion", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTransitKeyUseCase(t) + repo.EXPECT().GetByName(ctx, "k").Return(&transitDomain.TransitKey{ + Name: "k", + Version: 2, + }, nil) + repo.EXPECT(). + Create(ctx, mock.MatchedBy(func(k *transitDomain.TransitKey) bool { + return k.Version == 3 })). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Rotate(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, transitKey) - assert.Equal(t, "test-key", transitKey.Name) - assert.Equal(t, uint(2), transitKey.Version) - assert.Equal(t, newDek.ID, transitKey.DekID) - }) - - t.Run("Success_RotateCreatesFirstKeyIfNoneExist", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - newDek := createTestDek(kek.ID) - - // Setup expectations for transaction - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Times(2) - - mockTransitRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockTransitRepo.EXPECT(). - GetByNameAndVersion(mock.Anything, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(*newDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(mock.Anything, mock.MatchedBy(func(tk *transitDomain.TransitKey) bool { - return tk.Name == "test-key" && tk.Version == 1 + Return(nil) + + key, err := uc.Rotate(ctx, "k", cryptoDomain.AESGCM) + require.NoError(t, err) + assert.EqualValues(t, 3, key.Version) + }) + + t.Run("Success_CreatesFirstVersionWhenAbsent", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTransitKeyUseCase(t) + repo.EXPECT().GetByName(ctx, "new").Return(nil, transitDomain.ErrTransitKeyNotFound) + repo.EXPECT(). + GetByNameAndVersion(ctx, "new", uint(1)). + Return(nil, transitDomain.ErrTransitKeyNotFound) + repo.EXPECT(). + Create(ctx, mock.MatchedBy(func(k *transitDomain.TransitKey) bool { + return k.Version == 1 })). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Rotate(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, transitKey) - assert.Equal(t, uint(1), transitKey.Version) - }) - - t.Run("Success_RotateWithDifferentAlgorithm", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - existingDek := createTestDek(kek.ID) - currentKey := createTestTransitKey("test-key", 1, existingDek.ID) - newDek := createTestDek(kek.ID) - newDek.Algorithm = cryptoDomain.ChaCha20 - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(currentKey, nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.ChaCha20). - Return(*newDek, nil). - Once() - - mockDekRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - mockTransitRepo.EXPECT(). - Create(mock.Anything, mock.Anything). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Rotate(ctx, "test-key", cryptoDomain.ChaCha20) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, transitKey) - assert.Equal(t, uint(2), transitKey.Version) - }) - - t.Run("Error_GetByNameFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedError := errors.New("database error") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedError). - Once() - - mockTransitRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(nil, expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Rotate(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) - }) - - t.Run("Error_DekCreationFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - existingDek := createTestDek(kek.ID) - currentKey := createTestTransitKey("test-key", 1, existingDek.ID) - expectedError := errors.New("dek creation failed") - - // Setup expectations - mockTxManager.EXPECT(). - WithTx(ctx, mock.AnythingOfType("func(context.Context) error")). - Run(func(ctx context.Context, fn func(context.Context) error) { - _ = fn(ctx) - }). - Return(expectedError). - Once() + Return(nil) - mockTransitRepo.EXPECT(). - GetByName(mock.Anything, "test-key"). - Return(currentKey, nil). - Once() - - mockKeyManager.EXPECT(). - CreateDek(kek, cryptoDomain.AESGCM). - Return(cryptoDomain.Dek{}, expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - transitKey, err := uc.Rotate(ctx, "test-key", cryptoDomain.AESGCM) - - // Assert - assert.Error(t, err) - assert.Nil(t, transitKey) + key, err := uc.Rotate(ctx, "new", cryptoDomain.AESGCM) + require.NoError(t, err) + assert.EqualValues(t, 1, key.Version) }) } -// TestTransitKeyUseCase_Delete tests the Delete method of transitKeyUseCase. -func TestTransitKeyUseCase_Delete(t *testing.T) { +func TestTransitKeyUseCase_EncryptDecrypt_RoundTrip(t *testing.T) { + t.Parallel() ctx := context.Background() - t.Run("Success_SoftDeleteTransitKey", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - name := "test-key" - - // Setup expectations - mockTransitRepo.EXPECT(). - Delete(ctx, name). - Return(nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - err := uc.Delete(ctx, name) + uc, fake, repo := newTransitKeyUseCase(t) + dekID := allocateDekForTest(t, fake) - // Assert - assert.NoError(t, err) - }) - - t.Run("Error_DeleteFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) + transitKey := &transitDomain.TransitKey{ + Name: "k", + Version: 1, + DekID: dekID, + } + plaintext := []byte("payload") + aad := []byte("context") - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() + repo.EXPECT().GetByName(ctx, "k").Return(transitKey, nil) - name := "test-key" - expectedError := errors.New("database error") + encBlob, err := uc.Encrypt(ctx, "k", plaintext, aad) + require.NoError(t, err) + require.NotEmpty(t, encBlob.Ciphertext) - // Setup expectations - mockTransitRepo.EXPECT(). - Delete(ctx, name). - Return(expectedError). - Once() + // Roundtrip via the wire format. + wire := fmt.Sprintf("%d:%s", encBlob.Version, base64.StdEncoding.EncodeToString(encBlob.Ciphertext)) - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - err := uc.Delete(ctx, name) + repo.EXPECT().GetByNameAndVersion(ctx, "k", uint(1)).Return(transitKey, nil) - // Assert - assert.Error(t, err) - assert.Equal(t, expectedError, err) - }) + decBlob, err := uc.Decrypt(ctx, "k", wire, aad) + require.NoError(t, err) + assert.Equal(t, plaintext, decBlob.Plaintext) } -// TestTransitKeyUseCase_Encrypt tests the Encrypt method of transitKeyUseCase. -func TestTransitKeyUseCase_Encrypt(t *testing.T) { +func TestTransitKeyUseCase_Decrypt_BadFormat(t *testing.T) { + t.Parallel() ctx := context.Background() + uc, _, _ := newTransitKeyUseCase(t) - t.Run("Success_EncryptPlaintext", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - plaintext := []byte("sensitive data") - dekKey := make([]byte, 32) - ciphertext := []byte("encrypted-data") - nonce := []byte("random-nonce") - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - NonceSize(). - Return(12). - Maybe() - - mockCipher.EXPECT(). - Encrypt(plaintext, mock.Anything). - Return(ciphertext, nonce, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Encrypt(ctx, "test-key", plaintext, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, blob) - assert.Equal(t, uint(1), blob.Version) - assert.NotNil(t, blob.Ciphertext) - assert.Nil(t, blob.Plaintext) - }) - - t.Run("Error_TransitKeyNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - plaintext := []byte("sensitive data") - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Encrypt(ctx, "test-key", plaintext, nil) - - // Assert - assert.Error(t, err) - assert.Nil(t, blob) - assert.True(t, apperrors.Is(err, transitDomain.ErrTransitKeyNotFound)) - }) - - t.Run("Error_DekNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - plaintext := []byte("sensitive data") - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(nil, cryptoDomain.ErrDekNotFound). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Encrypt(ctx, "test-key", plaintext, nil) - - // Assert - assert.Error(t, err) - assert.Nil(t, blob) - assert.True(t, apperrors.Is(err, cryptoDomain.ErrDekNotFound)) - }) - - t.Run("Error_DecryptionFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - plaintext := []byte("sensitive data") - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(nil, cryptoDomain.ErrDecryptionFailed). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Encrypt(ctx, "test-key", plaintext, nil) - - // Assert - assert.Error(t, err) - assert.Nil(t, blob) - assert.True(t, apperrors.Is(err, cryptoDomain.ErrDecryptionFailed)) - }) - - t.Run("Success_EncryptWithContext", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - plaintext := []byte("sensitive data") - contextAAD := []byte("aead context") - dekKey := make([]byte, 32) - ciphertext := []byte("encrypted-data") - nonce := []byte("random-nonce") - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByName(ctx, "test-key"). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - Encrypt(plaintext, contextAAD). - Return(ciphertext, nonce, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Encrypt(ctx, "test-key", plaintext, contextAAD) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, blob) - mockCipher.AssertExpectations(t) - }) + _, err := uc.Decrypt(ctx, "k", "not-a-valid-wire-format", nil) + assert.Error(t, err) } -// TestTransitKeyUseCase_Decrypt tests the Decrypt method of transitKeyUseCase. -func TestTransitKeyUseCase_Decrypt(t *testing.T) { +func TestTransitKeyUseCase_Decrypt_CiphertextTooShort(t *testing.T) { + t.Parallel() ctx := context.Background() + uc, _, repo := newTransitKeyUseCase(t) - t.Run("Success_DecryptCiphertext", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - plaintext := []byte("sensitive data") - dekKey := make([]byte, 32) - - // Create a valid encrypted blob string - nonce := []byte("012345678901") // 12 bytes - ciphertext := []byte("encrypted-data-with-tag") - //nolint:gocritic // intentionally creating new slice for test data - encryptedData := append(nonce, ciphertext...) - blob := transitDomain.EncryptedBlob{ - Version: 1, - Ciphertext: encryptedData, - } - ciphertextStr := blob.String() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - NonceSize(). - Return(12). - Maybe() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(plaintext, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - resultBlob, err := uc.Decrypt(ctx, "test-key", ciphertextStr, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, resultBlob) - assert.Equal(t, uint(1), resultBlob.Version) - assert.Equal(t, plaintext, resultBlob.Plaintext) - assert.Nil(t, resultBlob.Ciphertext) - }) - - t.Run("Success_DecryptWithOldVersion", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 5, dek.ID) - plaintext := []byte("old data") - dekKey := make([]byte, 32) - - // Create a valid encrypted blob string with version 5 - nonce := []byte("012345678901") - ciphertext := []byte("old-encrypted-data") - //nolint:gocritic // intentionally creating new slice for test data - encryptedData := append(nonce, ciphertext...) - blob := transitDomain.EncryptedBlob{ - Version: 5, - Ciphertext: encryptedData, - } - ciphertextStr := blob.String() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(5)). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - NonceSize(). - Return(12). - Maybe() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, mock.Anything). - Return(plaintext, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - resultBlob, err := uc.Decrypt(ctx, "test-key", ciphertextStr, nil) - - // Assert - assert.NoError(t, err) - assert.NotNil(t, resultBlob) - assert.Equal(t, uint(5), resultBlob.Version) - assert.Equal(t, plaintext, resultBlob.Plaintext) - }) - - t.Run("Error_InvalidBlobFormat", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - invalidCiphertext := "invalid-blob-format" - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - blob, err := uc.Decrypt(ctx, "test-key", invalidCiphertext, nil) - - // Assert - assert.Error(t, err) - assert.Nil(t, blob) - assert.True(t, apperrors.Is(err, transitDomain.ErrInvalidBlobFormat)) - }) - - t.Run("Error_TransitKeyNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - // Create a valid encrypted blob string - blob := transitDomain.EncryptedBlob{ - Version: 1, - Ciphertext: []byte("data"), - } - ciphertextStr := blob.String() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(nil, transitDomain.ErrTransitKeyNotFound). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - resultBlob, err := uc.Decrypt(ctx, "test-key", ciphertextStr, nil) + repo.EXPECT(). + GetByNameAndVersion(ctx, "k", uint(1)). + Return(&transitDomain.TransitKey{Name: "k", Version: 1, DekID: uuid.New()}, nil) - // Assert - assert.Error(t, err) - assert.Nil(t, resultBlob) - assert.True(t, apperrors.Is(err, transitDomain.ErrTransitKeyNotFound)) - }) - - t.Run("Error_CiphertextTooShort", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - dekKey := make([]byte, 32) - - // Create a blob with data shorter than nonce size (12 bytes) - blob := transitDomain.EncryptedBlob{ - Version: 1, - Ciphertext: []byte("short"), // Only 5 bytes - } - ciphertextStr := blob.String() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - NonceSize(). - Return(12). - Maybe() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - resultBlob, err := uc.Decrypt(ctx, "test-key", ciphertextStr, nil) - - // Assert - assert.Error(t, err) - assert.Nil(t, resultBlob) - assert.True(t, apperrors.Is(err, cryptoDomain.ErrDecryptionFailed)) - }) - - t.Run("Error_DecryptWithWrongContext", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - mockCipher := serviceMocks.NewMockAEAD(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - dek := createTestDek(kek.ID) - transitKey := createTestTransitKey("test-key", 1, dek.ID) - dekKey := make([]byte, 32) - wrongContext := []byte("wrong context") - - // Create a valid encrypted blob string - nonce := []byte("012345678901") - ciphertext := []byte("encrypted-data") - //nolint:gocritic // intentionally creating new slice for test data - encryptedData := append(nonce, ciphertext...) - blob := transitDomain.EncryptedBlob{ - Version: 1, - Ciphertext: encryptedData, - } - ciphertextStr := blob.String() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetByNameAndVersion(ctx, "test-key", uint(1)). - Return(transitKey, nil). - Once() - - mockDekRepo.EXPECT(). - Get(ctx, dek.ID). - Return(dek, nil). - Once() - - mockKeyManager.EXPECT(). - DecryptDek(dek, kek). - Return(dekKey, nil). - Once() - - mockAeadManager.EXPECT(). - CreateCipher(dekKey, dek.Algorithm). - Return(mockCipher, nil). - Once() - - mockCipher.EXPECT(). - NonceSize(). - Return(12). - Maybe() - - mockCipher.EXPECT(). - Decrypt(ciphertext, nonce, wrongContext). - Return(nil, errors.New("authentication failed")). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - resultBlob, err := uc.Decrypt(ctx, "test-key", ciphertextStr, wrongContext) - - // Assert - assert.Error(t, err) - assert.Nil(t, resultBlob) - assert.True(t, apperrors.Is(err, cryptoDomain.ErrDecryptionFailed)) - mockCipher.AssertExpectations(t) - }) + // 5 bytes < 12-byte nonce + wire := fmt.Sprintf("1:%s", base64.StdEncoding.EncodeToString([]byte{1, 2, 3, 4, 5})) + _, err := uc.Decrypt(ctx, "k", wire, nil) + assert.ErrorIs(t, err, cryptoDomain.ErrDecryptionFailed) } -// TestTransitKeyUseCase_PurgeDeleted tests the PurgeDeleted method of transitKeyUseCase. -func TestTransitKeyUseCase_PurgeDeleted(t *testing.T) { +func TestTransitKeyUseCase_Get(t *testing.T) { + t.Parallel() ctx := context.Background() + uc, _, repo := newTransitKeyUseCase(t) - t.Run("Success_PurgeDeletedKeys", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedDeletedCount := int64(5) - - // Setup expectations - mockTransitRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedDeletedCount, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) + want := &transitDomain.TransitKey{Name: "k", Version: 1} + repo.EXPECT().GetTransitKey(ctx, "k", uint(0)).Return(want, cryptoDomain.AESGCM, nil) - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedDeletedCount, count) - }) - - t.Run("Success_PurgeDeletedKeys_DryRun", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := true - expectedDeletedCount := int64(10) - - // Setup expectations - mockTransitRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(expectedDeletedCount, nil). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.NoError(t, err) - assert.Equal(t, expectedDeletedCount, count) - }) - - t.Run("Error_InvalidOlderThanDays", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - olderThanDays := -1 - dryRun := false - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.Contains(t, err.Error(), "olderThanDays must be a positive number") - }) - t.Run("Error_HardDeleteFails", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - olderThanDays := 30 - dryRun := false - expectedError := errors.New("database error") - - // Setup expectations - mockTransitRepo.EXPECT(). - HardDelete(ctx, mock.AnythingOfType("time.Time"), dryRun). - Return(int64(0), expectedError). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - count, err := uc.PurgeDeleted(ctx, olderThanDays, dryRun) - - // Assert - assert.Error(t, err) - assert.Equal(t, int64(0), count) - assert.Equal(t, expectedError, err) - }) + got, alg, err := uc.Get(ctx, "k", 0) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, cryptoDomain.AESGCM, alg) } -// TestTransitKeyUseCase_Get tests the Get method of transitKeyUseCase. -func TestTransitKeyUseCase_Get(t *testing.T) { +func TestTransitKeyUseCase_Delete(t *testing.T) { + t.Parallel() ctx := context.Background() + uc, _, repo := newTransitKeyUseCase(t) - t.Run("Success_GetTransitKey", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) - - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - expectedKey := createTestTransitKey("test-key", 1, uuid.Must(uuid.NewV7())) - expectedAlg := cryptoDomain.AESGCM - - // Setup expectations - mockTransitRepo.EXPECT(). - GetTransitKey(ctx, "test-key", uint(1)). - Return(expectedKey, expectedAlg, nil). - Once() + repo.EXPECT().Delete(ctx, "k").Return(nil) + assert.NoError(t, uc.Delete(ctx, "k")) +} - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - key, alg, err := uc.Get(ctx, "test-key", 1) +func TestTransitKeyUseCase_PurgeDeleted(t *testing.T) { + t.Parallel() + ctx := context.Background() - // Assert - assert.NoError(t, err) - assert.NotNil(t, key) - assert.Equal(t, expectedKey, key) - assert.Equal(t, expectedAlg, alg) + t.Run("Error_NegativeDays", func(t *testing.T) { + t.Parallel() + uc, _, _ := newTransitKeyUseCase(t) + _, err := uc.PurgeDeleted(ctx, -1, false) + assert.Error(t, err) }) - t.Run("Error_GetTransitKeyNotFound", func(t *testing.T) { - // Setup mocks - mockTxManager := databaseMocks.NewMockTxManager(t) - mockTransitRepo := usecaseMocks.NewMockTransitKeyRepository(t) - mockDekRepo := usecaseMocks.NewMockDekRepository(t) - mockKeyManager := serviceMocks.NewMockKeyManager(t) - mockAeadManager := serviceMocks.NewMockAEADManager(t) + t.Run("Success_DryRun", func(t *testing.T) { + t.Parallel() + uc, _, repo := newTransitKeyUseCase(t) + repo.EXPECT().HardDelete(ctx, mock.Anything, true).Return(int64(2), nil) - // Create test data - kek := createTestKek() - kekChain := createTestKekChain(kek.ID, kek) - defer kekChain.Close() - - // Setup expectations - mockTransitRepo.EXPECT(). - GetTransitKey(ctx, "test-key", uint(1)). - Return(nil, cryptoDomain.Algorithm(""), transitDomain.ErrTransitKeyNotFound). - Once() - - // Execute - uc := NewTransitKeyUseCase( - mockTxManager, mockTransitRepo, mockDekRepo, mockKeyManager, mockAeadManager, kekChain, - ) - key, alg, err := uc.Get(ctx, "test-key", 1) - - // Assert - assert.Error(t, err) - assert.Nil(t, key) - assert.Equal(t, cryptoDomain.Algorithm(""), alg) - assert.True(t, apperrors.Is(err, transitDomain.ErrTransitKeyNotFound)) + n, err := uc.PurgeDeleted(ctx, 30, true) + require.NoError(t, err) + assert.EqualValues(t, 2, n) }) }