Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions modules/database/aws_iam_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"errors"

"github.com/GoCodeAlone/modular"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
Expand All @@ -23,20 +24,35 @@ var (
ErrNoUserInfoInDSN = errors.New("no user information in DSN to replace password")
)

// TokenRefreshCallback is called when a token is refreshed
type TokenRefreshCallback func(newToken string, endpoint string)

// IAMTokenProvider defines the interface for AWS IAM token providers
type IAMTokenProvider interface {
GetToken(ctx context.Context, endpoint string) (string, error)
BuildDSNWithIAMToken(ctx context.Context, originalDSN string) (string, error)
StartTokenRefresh(ctx context.Context, endpoint string)
StopTokenRefresh()
// SetTokenRefreshCallback sets a callback to be notified when tokens are refreshed
SetTokenRefreshCallback(callback TokenRefreshCallback)
}

// AWSIAMTokenProvider manages AWS IAM authentication tokens for RDS
type AWSIAMTokenProvider struct {
config *AWSIAMAuthConfig
awsConfig aws.Config
currentToken string
tokenExpiry time.Time
mutex sync.RWMutex
stopChan chan struct{}
refreshDone chan struct{}
refreshStarted bool
config *AWSIAMAuthConfig
awsConfig aws.Config
currentToken string
tokenExpiry time.Time
mutex sync.RWMutex
stopChan chan struct{}
refreshDone chan struct{}
refreshStarted bool
tokenRefreshCallback TokenRefreshCallback
logger modular.Logger // Logger service for error reporting
}

// NewAWSIAMTokenProvider creates a new AWS IAM token provider
func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig) (*AWSIAMTokenProvider, error) {
func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig, logger modular.Logger) (*AWSIAMTokenProvider, error) {
if authConfig == nil || !authConfig.Enabled {
return nil, ErrIAMAuthNotEnabled
}
Expand Down Expand Up @@ -66,6 +82,7 @@ func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig) (*AWSIAMTokenProvider,
awsConfig: awsConfig,
stopChan: make(chan struct{}),
refreshDone: make(chan struct{}),
logger: logger,
}

return provider, nil
Expand Down Expand Up @@ -104,6 +121,22 @@ func (p *AWSIAMTokenProvider) refreshToken(ctx context.Context, endpoint string)
// Tokens are valid for 15 minutes, we refresh earlier to avoid expiry
p.tokenExpiry = time.Now().Add(time.Duration(p.config.TokenRefreshInterval) * time.Second)

// Notify callback if set (this allows database service to recreate connections)
if p.tokenRefreshCallback != nil {
// Call callback in a separate goroutine to avoid blocking token refresh
// Add panic recovery to prevent callback panics from affecting token refresh
go func() {
defer func() {
if r := recover(); r != nil {
// Log the panic but don't fail the token refresh process
// Use the logger service for proper error reporting
p.logger.Error("Database token refresh callback panic recovered", "panic", r)
}
}()
p.tokenRefreshCallback(token, endpoint)
}()
}

return token, nil
}

Expand Down Expand Up @@ -133,6 +166,13 @@ func (p *AWSIAMTokenProvider) StopTokenRefresh() {
<-p.refreshDone
}

// SetTokenRefreshCallback sets a callback to be notified when tokens are refreshed
func (p *AWSIAMTokenProvider) SetTokenRefreshCallback(callback TokenRefreshCallback) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.tokenRefreshCallback = callback
}

// tokenRefreshLoop runs in the background to refresh tokens
func (p *AWSIAMTokenProvider) tokenRefreshLoop(ctx context.Context, endpoint string) {
defer close(p.refreshDone)
Expand Down
21 changes: 16 additions & 5 deletions modules/database/aws_iam_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ import (
"github.com/GoCodeAlone/modular/feeders"
)

// MockLogger implements the Logger interface for testing
type MockLogger struct{}

func (m *MockLogger) Info(msg string, args ...any) {}
func (m *MockLogger) Error(msg string, args ...any) {}
func (m *MockLogger) Warn(msg string, args ...any) {}
func (m *MockLogger) Debug(msg string, args ...any) {}

func TestAWSIAMAuthConfig(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -71,7 +79,8 @@ func TestAWSIAMAuthConfig(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := NewAWSIAMTokenProvider(tt.config)
mockLogger := &MockLogger{}
provider, err := NewAWSIAMTokenProvider(tt.config, mockLogger)
if tt.wantErr {
require.Error(t, err)
require.Nil(t, provider)
Expand Down Expand Up @@ -232,7 +241,7 @@ func TestDatabaseServiceWithAWSIAMAuth(t *testing.T) {

// Skip this test if AWS credentials are not available
// The test will create the service but not actually connect
service, err := NewDatabaseService(config)
service, err := NewDatabaseService(config, &MockLogger{})
if err != nil {
// If AWS config loading fails, skip this test
if strings.Contains(err.Error(), "failed to load AWS config") {
Expand Down Expand Up @@ -261,7 +270,7 @@ func TestDatabaseServiceWithoutAWSIAMAuth(t *testing.T) {
DSN: "postgres://user:password@localhost:5432/mydb",
}

service, err := NewDatabaseService(config)
service, err := NewDatabaseService(config, &MockLogger{})
require.NoError(t, err)
require.NotNil(t, service)

Expand All @@ -286,7 +295,8 @@ func TestAWSIAMTokenProvider_NoDeadlockOnClose(t *testing.T) {
TokenRefreshInterval: 300,
}

provider, err := NewAWSIAMTokenProvider(config)
mockLogger := &MockLogger{}
provider, err := NewAWSIAMTokenProvider(config, mockLogger)
if err != nil {
if strings.Contains(err.Error(), "failed to load AWS config") {
t.Skip("AWS credentials not available, skipping test")
Expand Down Expand Up @@ -318,7 +328,8 @@ func TestAWSIAMTokenProvider_StartStopRefresh(t *testing.T) {
TokenRefreshInterval: 1, // Short interval for testing
}

provider, err := NewAWSIAMTokenProvider(config)
mockLogger := &MockLogger{}
provider, err := NewAWSIAMTokenProvider(config, mockLogger)
if err != nil {
if strings.Contains(err.Error(), "failed to load AWS config") {
t.Skip("AWS credentials not available, skipping test")
Expand Down
4 changes: 4 additions & 0 deletions modules/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,8 @@ type AWSIAMAuthConfig struct {
// TokenRefreshInterval specifies how often to refresh the IAM token (in seconds)
// Default is 10 minutes (600 seconds), tokens expire after 15 minutes
TokenRefreshInterval int `json:"token_refresh_interval" yaml:"token_refresh_interval" env:"AWS_IAM_AUTH_TOKEN_REFRESH" default:"600"`

// ConnectionTimeout specifies the timeout for database connection tests (in seconds)
// Default is 5 seconds
ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" env:"AWS_IAM_AUTH_CONNECTION_TIMEOUT" default:"5s"`
}
211 changes: 211 additions & 0 deletions modules/database/coverage_improvement_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
package database

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
_ "modernc.org/sqlite" // Import sqlite driver for testing
)

// TestOnTokenRefresh tests the onTokenRefresh method in service.go
func TestOnTokenRefresh(t *testing.T) {
// Test early return when db is nil
t.Run("early_return_when_db_is_nil", func(t *testing.T) {
service := &databaseServiceImpl{
config: ConnectionConfig{
Driver: "sqlite",
DSN: "test.db",
},
db: nil, // db is nil
logger: &MockLogger{},
ctx: context.Background(),
}

// This should return early and not panic
service.onTokenRefresh("new-token", "test-endpoint")
// Test passes if no panic occurs
})
}

// TestAWSIAMTokenProviderGetToken tests the GetToken method for cached token scenario
func TestAWSIAMTokenProviderGetToken(t *testing.T) {
t.Run("returns_cached_valid_token", func(t *testing.T) {
provider := &AWSIAMTokenProvider{
currentToken: "cached-token",
tokenExpiry: time.Now().Add(5 * time.Minute),
}

token, err := provider.GetToken(context.Background(), "test-endpoint:5432")

assert.NoError(t, err)
assert.Equal(t, "cached-token", token)
})
}

// TestAWSIAMTokenProviderBuildDSNWithIAMToken tests the BuildDSNWithIAMToken method
func TestAWSIAMTokenProviderBuildDSNWithIAMToken(t *testing.T) {
t.Run("builds_dsn_with_cached_token", func(t *testing.T) {
provider := &AWSIAMTokenProvider{
currentToken: "cached-token",
tokenExpiry: time.Now().Add(5 * time.Minute),
}

dsn, err := provider.BuildDSNWithIAMToken(context.Background(), "postgres://user:password@localhost:5432/db")

assert.NoError(t, err)
assert.Contains(t, dsn, "cached-token")
})
}

// TestAWSIAMTokenProviderSetTokenRefreshCallback tests the SetTokenRefreshCallback method
func TestAWSIAMTokenProviderSetTokenRefreshCallback(t *testing.T) {
provider := &AWSIAMTokenProvider{}

callbackCalled := false
callback := func(token, endpoint string) {
callbackCalled = true
}

provider.SetTokenRefreshCallback(callback)

// Test that callback is stored
assert.NotNil(t, provider.tokenRefreshCallback)

// Test that callback can be called
provider.tokenRefreshCallback("test-token", "test-endpoint")
assert.True(t, callbackCalled)
}

// TestAWSIAMTokenProviderStopTokenRefresh tests the StopTokenRefresh method
func TestAWSIAMTokenProviderStopTokenRefresh(t *testing.T) {
t.Run("stops_token_refresh_when_not_started", func(t *testing.T) {
provider := &AWSIAMTokenProvider{
refreshStarted: false,
}

// This should return early and not block
provider.StopTokenRefresh()
// Test passes if method returns without blocking
})
}

// TestDatabaseServiceImplDB tests the DB method with connection mutex
func TestDatabaseServiceImplDB(t *testing.T) {
service := &databaseServiceImpl{
connMutex: sync.RWMutex{},
db: nil,
}

// This should not panic even when db is nil
db := service.DB()
assert.Nil(t, db)
}

// TestReplacesDSNPasswordFunction tests replaceDSNPassword function edge cases
func TestReplacesDSNPasswordFunction(t *testing.T) {
t.Run("handles_url_style_dsn_without_userinfo", func(t *testing.T) {
dsn := "postgres://localhost:5432/database"
_, err := replaceDSNPassword(dsn, "new-token")

assert.Error(t, err)
assert.Contains(t, err.Error(), "no user information in DSN")
})

t.Run("adds_password_to_key_value_dsn_when_missing", func(t *testing.T) {
dsn := "host=localhost port=5432 user=testuser dbname=testdb"
newDSN, err := replaceDSNPassword(dsn, "new-token")

assert.NoError(t, err)
assert.Contains(t, newDSN, "password=new-token")
})
}

// TestLooksLikeHostnameFunction tests looksLikeHostname function edge cases
func TestLooksLikeHostnameFunction(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"empty_string", "", false},
{"valid_hostname_with_port", "localhost:5432", true},
{"valid_hostname_with_dot", "db.example.com", true},
{"localhost_only", "localhost", true},
{"invalid_with_special_chars", "host!@#$", false},
{"hostname_with_path", "db.example.com/path", true},
{"hostname_with_query", "db.example.com?param=value", true},
{"starts_with_number", "127.0.0.1", true},
{"starts_with_special_char", "!invalid", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := looksLikeHostname(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

// TestIsHexDigitFunction tests isHexDigit function
func TestIsHexDigitFunction(t *testing.T) {
tests := []struct {
name string
input byte
expected bool
}{
{"digit_0", '0', true},
{"digit_9", '9', true},
{"uppercase_A", 'A', true},
{"uppercase_F", 'F', true},
{"lowercase_a", 'a', true},
{"lowercase_f", 'f', true},
{"invalid_G", 'G', false},
{"invalid_g", 'g', false},
{"invalid_special", '!', false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isHexDigit(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

// TestPreprocessDSNForParsingFunction tests preprocessDSNForParsing function edge cases
func TestPreprocessDSNForParsingFunction(t *testing.T) {
t.Run("returns_non_url_dsn_unchanged", func(t *testing.T) {
dsn := "host=localhost port=5432"
result, err := preprocessDSNForParsing(dsn)

assert.NoError(t, err)
assert.Equal(t, dsn, result)
})

t.Run("returns_dsn_without_credentials_unchanged", func(t *testing.T) {
dsn := "postgres://localhost:5432/database"
result, err := preprocessDSNForParsing(dsn)

assert.NoError(t, err)
assert.Equal(t, dsn, result)
})

t.Run("returns_dsn_without_password_unchanged", func(t *testing.T) {
dsn := "postgres://username@localhost:5432/database"
result, err := preprocessDSNForParsing(dsn)

assert.NoError(t, err)
assert.Equal(t, dsn, result)
})

t.Run("returns_already_encoded_dsn_unchanged", func(t *testing.T) {
dsn := "postgres://username:password%21@localhost:5432/database"
result, err := preprocessDSNForParsing(dsn)

assert.NoError(t, err)
assert.Equal(t, dsn, result)
})
}
Loading
Loading