From 9f59f3952288249c0a461ac734f4e42813f136b1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Sep 2025 06:29:05 +0000 Subject: [PATCH 1/4] Initial plan From 0f407c8c3de291fcca246c8271557ef327e9d0df Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Sep 2025 07:03:14 +0000 Subject: [PATCH 2/4] Merge critical IAM token refresh bug fix from upstream Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- modules/database/aws_iam_auth.go | 53 +++++++++-- modules/database/aws_iam_auth_test.go | 12 ++- modules/database/config.go | 4 + modules/database/database_module_bdd_test.go | 6 +- modules/database/debug_test.go | 2 +- modules/database/dsn_special_chars_test.go | 12 +-- modules/database/health_test.go | 9 +- modules/database/module.go | 6 +- modules/database/module_test.go | 19 ++-- modules/database/service.go | 98 +++++++++++++++++++- 10 files changed, 177 insertions(+), 44 deletions(-) diff --git a/modules/database/aws_iam_auth.go b/modules/database/aws_iam_auth.go index c7bcb62b..4cbd03b1 100644 --- a/modules/database/aws_iam_auth.go +++ b/modules/database/aws_iam_auth.go @@ -23,16 +23,30 @@ var ( ErrNoUserInfoInDSN = errors.New("no user information in DSN to replace password") ) +// TokenRefreshCallback is called when a token is refreshed +type TokenRefreshCallback func(newToken string, endpoint string) + +// IAMTokenProvider defines the interface for AWS IAM token providers +type IAMTokenProvider interface { + GetToken(ctx context.Context, endpoint string) (string, error) + BuildDSNWithIAMToken(ctx context.Context, originalDSN string) (string, error) + StartTokenRefresh(ctx context.Context, endpoint string) + StopTokenRefresh() + // SetTokenRefreshCallback sets a callback to be notified when tokens are refreshed + SetTokenRefreshCallback(callback TokenRefreshCallback) +} + // AWSIAMTokenProvider manages AWS IAM authentication tokens for RDS type AWSIAMTokenProvider struct { - config *AWSIAMAuthConfig - awsConfig aws.Config - currentToken string - tokenExpiry time.Time - mutex sync.RWMutex - stopChan chan struct{} - refreshDone chan struct{} - refreshStarted bool + config *AWSIAMAuthConfig + awsConfig aws.Config + currentToken string + tokenExpiry time.Time + mutex sync.RWMutex + stopChan chan struct{} + refreshDone chan struct{} + refreshStarted bool + tokenRefreshCallback TokenRefreshCallback } // NewAWSIAMTokenProvider creates a new AWS IAM token provider @@ -104,6 +118,22 @@ func (p *AWSIAMTokenProvider) refreshToken(ctx context.Context, endpoint string) // Tokens are valid for 15 minutes, we refresh earlier to avoid expiry p.tokenExpiry = time.Now().Add(time.Duration(p.config.TokenRefreshInterval) * time.Second) + // Notify callback if set (this allows database service to recreate connections) + if p.tokenRefreshCallback != nil { + // Call callback in a separate goroutine to avoid blocking token refresh + // Add panic recovery to prevent callback panics from affecting token refresh + go func() { + defer func() { + if r := recover(); r != nil { + // Log the panic but don't fail the token refresh process + // The actual logging will be handled by the callback implementation if available + fmt.Printf("Database token refresh callback panic recovered: %v\n", r) + } + }() + p.tokenRefreshCallback(token, endpoint) + }() + } + return token, nil } @@ -133,6 +163,13 @@ func (p *AWSIAMTokenProvider) StopTokenRefresh() { <-p.refreshDone } +// SetTokenRefreshCallback sets a callback to be notified when tokens are refreshed +func (p *AWSIAMTokenProvider) SetTokenRefreshCallback(callback TokenRefreshCallback) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.tokenRefreshCallback = callback +} + // tokenRefreshLoop runs in the background to refresh tokens func (p *AWSIAMTokenProvider) tokenRefreshLoop(ctx context.Context, endpoint string) { defer close(p.refreshDone) diff --git a/modules/database/aws_iam_auth_test.go b/modules/database/aws_iam_auth_test.go index 6a2782ce..7491bc50 100644 --- a/modules/database/aws_iam_auth_test.go +++ b/modules/database/aws_iam_auth_test.go @@ -14,6 +14,14 @@ import ( "github.com/GoCodeAlone/modular/feeders" ) +// MockLogger implements the Logger interface for testing +type MockLogger struct{} + +func (m *MockLogger) Info(msg string, args ...any) {} +func (m *MockLogger) Error(msg string, args ...any) {} +func (m *MockLogger) Warn(msg string, args ...any) {} +func (m *MockLogger) Debug(msg string, args ...any) {} + func TestAWSIAMAuthConfig(t *testing.T) { tests := []struct { name string @@ -232,7 +240,7 @@ func TestDatabaseServiceWithAWSIAMAuth(t *testing.T) { // Skip this test if AWS credentials are not available // The test will create the service but not actually connect - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) if err != nil { // If AWS config loading fails, skip this test if strings.Contains(err.Error(), "failed to load AWS config") { @@ -261,7 +269,7 @@ func TestDatabaseServiceWithoutAWSIAMAuth(t *testing.T) { DSN: "postgres://user:password@localhost:5432/mydb", } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) require.NotNil(t, service) diff --git a/modules/database/config.go b/modules/database/config.go index ad99937e..41e5c69b 100644 --- a/modules/database/config.go +++ b/modules/database/config.go @@ -67,4 +67,8 @@ type AWSIAMAuthConfig struct { // TokenRefreshInterval specifies how often to refresh the IAM token (in seconds) // Default is 10 minutes (600 seconds), tokens expire after 15 minutes TokenRefreshInterval int `json:"token_refresh_interval" yaml:"token_refresh_interval" env:"AWS_IAM_AUTH_TOKEN_REFRESH" default:"600"` + + // ConnectionTimeout specifies the timeout for database connection tests (in seconds) + // Default is 5 seconds + ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" env:"AWS_IAM_AUTH_CONNECTION_TIMEOUT" default:"5s"` } diff --git a/modules/database/database_module_bdd_test.go b/modules/database/database_module_bdd_test.go index 2be73777..3ae91da5 100644 --- a/modules/database/database_module_bdd_test.go +++ b/modules/database/database_module_bdd_test.go @@ -140,7 +140,7 @@ func (ctx *DatabaseBDDTestContext) iHaveAModularApplicationWithDatabaseModuleCon // HACK: Manually set the config and reinitialize connections // This is needed because the instance-aware provider doesn't get our config ctx.module.config = dbConfig - if err := ctx.module.initializeConnections(); err != nil { + if err := ctx.module.initializeConnections(ctx.app); err != nil { return fmt.Errorf("failed to initialize connections manually: %v", err) } @@ -475,7 +475,7 @@ func (ctx *DatabaseBDDTestContext) iHaveADatabaseServiceWithEventObservationEnab // HACK: Manually set the config and reinitialize connections // This is needed because the instance-aware provider doesn't get our config ctx.module.config = dbConfig - if err := ctx.module.initializeConnections(); err != nil { + if err := ctx.module.initializeConnections(ctx.app); err != nil { return fmt.Errorf("failed to initialize connections manually: %v", err) } @@ -716,7 +716,7 @@ func (ctx *DatabaseBDDTestContext) aDatabaseConnectionFailsWithInvalidCredential } // Create a service that will fail to connect - badService, err := NewDatabaseService(badConfig) + badService, err := NewDatabaseService(badConfig, &MockLogger{}) if err != nil { // Driver error - this is before connection, which is what we want ctx.connectionError = err diff --git a/modules/database/debug_test.go b/modules/database/debug_test.go index 77c17d4e..88d95b60 100644 --- a/modules/database/debug_test.go +++ b/modules/database/debug_test.go @@ -12,7 +12,7 @@ func TestDebugTableCreation(t *testing.T) { DSN: ":memory:", } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) require.NotNil(t, service) diff --git a/modules/database/dsn_special_chars_test.go b/modules/database/dsn_special_chars_test.go index ab3e67f1..5abd52b0 100644 --- a/modules/database/dsn_special_chars_test.go +++ b/modules/database/dsn_special_chars_test.go @@ -30,7 +30,7 @@ func TestSpecialCharacterPasswordDSNParsing(t *testing.T) { DSN: issueExampleDSN, } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) require.NotNil(t, service) @@ -57,7 +57,7 @@ func TestSpecialCharacterPasswordDSNParsingWithAWSIAM(t *testing.T) { } // Skip this test if AWS credentials are not available - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) if err != nil { // If AWS config loading fails, skip this test if err.Error() == "failed to create AWS IAM token provider: failed to load AWS config: no EC2 IMDS role found, operation error ec2imds: GetMetadata, canceled, context canceled" { @@ -154,7 +154,7 @@ func TestExactFailingScenario(t *testing.T) { DSN: problematicDSN, } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err, "NewDatabaseService should not fail with special characters in DSN") require.NotNil(t, service) @@ -193,7 +193,7 @@ func TestDSNParsingWithoutAWSIAM(t *testing.T) { // No AWSIAMAuth - this should still work } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err, "NewDatabaseService should not fail with special characters in DSN") require.NotNil(t, service) @@ -215,7 +215,7 @@ func TestNonAWSIAMSpecialCharsDSNConnection(t *testing.T) { // No AWS IAM auth configured - this should trigger the bug } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err, "NewDatabaseService should not fail") require.NotNil(t, service) @@ -283,7 +283,7 @@ func TestServiceConnectWithoutPreprocessing(t *testing.T) { // No AWS IAM auth - this is the scenario where the bug occurs } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err, "NewDatabaseService should succeed") // Clean up diff --git a/modules/database/health_test.go b/modules/database/health_test.go index e75d0396..fe7ffdd8 100644 --- a/modules/database/health_test.go +++ b/modules/database/health_test.go @@ -31,7 +31,8 @@ func TestModule_HealthCheck_WithHealthyDatabase(t *testing.T) { } // Initialize the module to establish connections - err := module.initializeConnections() + app := NewMockApplication() + err := module.initializeConnections(app) require.NoError(t, err) // Act: Perform health check @@ -124,7 +125,8 @@ func TestModule_HealthCheck_MultipleConnections(t *testing.T) { } // Initialize the module to establish connections - err := module.initializeConnections() + app := NewMockApplication() + err := module.initializeConnections(app) require.NoError(t, err) // Act: Perform health check @@ -169,7 +171,8 @@ func TestModule_HealthCheck_WithContext(t *testing.T) { } // Initialize the module to establish connections - err := module.initializeConnections() + app := NewMockApplication() + err := module.initializeConnections(app) require.NoError(t, err) // Act: Create a cancelled context diff --git a/modules/database/module.go b/modules/database/module.go index 93f21ca0..08fad71e 100644 --- a/modules/database/module.go +++ b/modules/database/module.go @@ -508,7 +508,7 @@ func (m *Module) Init(app modular.Application) error { }() // Initialize connections - if err := m.initializeConnections(); err != nil { + if err := m.initializeConnections(app); err != nil { return fmt.Errorf("failed to initialize database connections: %w", err) } @@ -695,7 +695,7 @@ func (m *Module) GetService(name string) (DatabaseService, bool) { // initializeConnections initializes database connections based on the module's configuration. // This method processes each configured connection, creates database services, // and establishes initial connectivity to validate the configuration. -func (m *Module) initializeConnections() error { +func (m *Module) initializeConnections(app modular.Application) error { // Initialize database connections if len(m.config.Connections) > 0 { for name, connConfig := range m.config.Connections { @@ -707,7 +707,7 @@ func (m *Module) initializeConnections() error { } // Create the database service and connect - dbService, err := NewDatabaseService(*connConfig) + dbService, err := NewDatabaseService(*connConfig, app.Logger()) if err != nil { return fmt.Errorf("failed to create database service for '%s': %w", name, err) } diff --git a/modules/database/module_test.go b/modules/database/module_test.go index f20d6875..a49946d5 100644 --- a/modules/database/module_test.go +++ b/modules/database/module_test.go @@ -103,13 +103,6 @@ type MockConfigProvider struct { func (m *MockConfigProvider) GetConfig() any { return m.config } -type MockLogger struct{} - -func (l *MockLogger) Debug(msg string, args ...any) {} -func (l *MockLogger) Info(msg string, args ...any) {} -func (l *MockLogger) Warn(msg string, args ...any) {} -func (l *MockLogger) Error(msg string, args ...any) {} - func TestNewModule(t *testing.T) { module := NewModule() assert.NotNil(t, module) @@ -282,7 +275,7 @@ func TestDatabaseServiceFactory(t *testing.T) { DSN: tt.dsn, } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) if tt.shouldSucceed { require.NoError(t, err) assert.NotNil(t, service) @@ -303,7 +296,7 @@ func TestDatabaseService_Operations(t *testing.T) { MaxOpenConnections: 5, // Allow multiple connections for parallel subtests } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) require.NotNil(t, service) @@ -407,7 +400,7 @@ func TestDatabaseService_ErrorHandling(t *testing.T) { Driver: "sqlite", DSN: ":memory:", } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) ctx := context.Background() @@ -455,7 +448,7 @@ func TestDatabaseService_ErrorHandling(t *testing.T) { DSN: "test://localhost", } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) require.NoError(t, err) // Service creation should succeed assert.NotNil(t, service) @@ -571,7 +564,7 @@ func BenchmarkDatabaseService_Connect(b *testing.B) { } for i := 0; i < b.N; i++ { - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) if err != nil { b.Skipf("Skipping benchmark - SQLite3 requires CGO: %v", err) return @@ -594,7 +587,7 @@ func BenchmarkDatabaseService_Query(b *testing.B) { DSN: ":memory:", } - service, err := NewDatabaseService(config) + service, err := NewDatabaseService(config, &MockLogger{}) if err != nil { b.Skipf("Skipping benchmark - SQLite3 requires CGO: %v", err) return diff --git a/modules/database/service.go b/modules/database/service.go index 57dc53f3..abe763a2 100644 --- a/modules/database/service.go +++ b/modules/database/service.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "sync" "time" "github.com/GoCodeAlone/modular" @@ -18,6 +19,12 @@ var ( ErrDatabaseNotConnected = errors.New("database not connected") ) +// Constants for database service +const ( + // DefaultConnectionTimeout is the default timeout for database connection tests + DefaultConnectionTimeout = 5 * time.Second +) + // DatabaseService defines the operations that can be performed with a database type DatabaseService interface { // Connect establishes the database connection @@ -97,15 +104,18 @@ type DatabaseService interface { type databaseServiceImpl struct { config ConnectionConfig db *sql.DB - awsTokenProvider *AWSIAMTokenProvider + awsTokenProvider IAMTokenProvider migrationService MigrationService eventEmitter EventEmitter + logger modular.Logger // Logger service for error reporting ctx context.Context cancel context.CancelFunc + endpoint string // Store endpoint for reconnection + connMutex sync.RWMutex // Protect database connection during recreation } // NewDatabaseService creates a new database service from configuration -func NewDatabaseService(config ConnectionConfig) (DatabaseService, error) { +func NewDatabaseService(config ConnectionConfig, logger modular.Logger) (DatabaseService, error) { if config.Driver == "" { return nil, ErrEmptyDriver } @@ -119,6 +129,7 @@ func NewDatabaseService(config ConnectionConfig) (DatabaseService, error) { config: config, ctx: ctx, cancel: cancel, + logger: logger, } // Initialize AWS IAM token provider if enabled @@ -145,12 +156,17 @@ func (s *databaseServiceImpl) Connect() error { return fmt.Errorf("failed to build DSN with IAM token: %w", err) } - // Start background token refresh - endpoint, err := extractEndpointFromDSN(s.config.DSN) + // Extract and store endpoint for token refresh + s.endpoint, err = extractEndpointFromDSN(s.config.DSN) if err != nil { return fmt.Errorf("failed to extract endpoint for token refresh: %w", err) } - s.awsTokenProvider.StartTokenRefresh(s.ctx, endpoint) + + // Set up token refresh callback to recreate connections when tokens are refreshed + s.awsTokenProvider.SetTokenRefreshCallback(s.onTokenRefresh) + + // Start background token refresh + s.awsTokenProvider.StartTokenRefresh(s.ctx, s.endpoint) } else { // Only preprocess when NOT using AWS IAM auth (since AWS IAM auth does its own preprocessing) var err error @@ -221,6 +237,8 @@ func (s *databaseServiceImpl) Close() error { } func (s *databaseServiceImpl) DB() *sql.DB { + s.connMutex.RLock() + defer s.connMutex.RUnlock() return s.db } @@ -437,3 +455,73 @@ func (s *databaseServiceImpl) CreateMigrationsTable(ctx context.Context) error { } return nil } + +// onTokenRefresh is called when IAM token is refreshed to recreate database connections +func (s *databaseServiceImpl) onTokenRefresh(newToken string, endpoint string) { + // Recreate database connection with new token + s.connMutex.Lock() + defer s.connMutex.Unlock() + + if s.db == nil { + return // Connection already closed + } + + // Close existing connections to force pool refresh + oldDB := s.db + + // Build new DSN with refreshed token + newDSN, err := s.awsTokenProvider.BuildDSNWithIAMToken(s.ctx, s.config.DSN) + if err != nil { + // Log error but don't crash the application + s.logger.Error("Failed to build DSN with refreshed IAM token", "error", err, "endpoint", endpoint) + return + } + + // Create new database connection + newDB, err := sql.Open(s.config.Driver, newDSN) + if err != nil { + s.logger.Error("Failed to create new database connection with refreshed token", "error", err, "endpoint", endpoint) + return + } + + // Configure connection pool settings + if s.config.MaxOpenConnections > 0 { + newDB.SetMaxOpenConns(s.config.MaxOpenConnections) + } + if s.config.MaxIdleConnections > 0 { + newDB.SetMaxIdleConns(s.config.MaxIdleConnections) + } + if s.config.ConnectionMaxLifetime > 0 { + newDB.SetConnMaxLifetime(s.config.ConnectionMaxLifetime) + } + if s.config.ConnectionMaxIdleTime > 0 { + newDB.SetConnMaxIdleTime(s.config.ConnectionMaxIdleTime) + } + + // Test the new connection with a timeout + timeout := DefaultConnectionTimeout + if s.config.AWSIAMAuth != nil && s.config.AWSIAMAuth.ConnectionTimeout > 0 { + timeout = s.config.AWSIAMAuth.ConnectionTimeout + } + + testCtx, cancel := context.WithTimeout(s.ctx, timeout) + defer cancel() + + if err := newDB.PingContext(testCtx); err != nil { + s.logger.Error("Failed to ping database with refreshed token", "error", err, "endpoint", endpoint) + newDB.Close() + return + } + + // Replace old connection with new one + s.db = newDB + + // Close old connection in background to avoid blocking + go func() { + if err := oldDB.Close(); err != nil { + s.logger.Warn("Failed to close old database connection", "error", err) + } + }() + + s.logger.Info("Successfully refreshed database connection with new IAM token", "endpoint", endpoint) +} From 5c93381e228ec12fd5c42abf6eb0cfb06eaa2fcf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Sep 2025 07:13:26 +0000 Subject: [PATCH 3/4] Complete upstream merge with test coverage improvements Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- modules/database/coverage_improvement_test.go | 211 ++++++++++++++++++ .../token_refresh_integration_test.go | 77 +++++++ 2 files changed, 288 insertions(+) create mode 100644 modules/database/coverage_improvement_test.go create mode 100644 modules/database/token_refresh_integration_test.go diff --git a/modules/database/coverage_improvement_test.go b/modules/database/coverage_improvement_test.go new file mode 100644 index 00000000..ab8dd37f --- /dev/null +++ b/modules/database/coverage_improvement_test.go @@ -0,0 +1,211 @@ +package database + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + _ "modernc.org/sqlite" // Import sqlite driver for testing +) + +// TestOnTokenRefresh tests the onTokenRefresh method in service.go +func TestOnTokenRefresh(t *testing.T) { + // Test early return when db is nil + t.Run("early_return_when_db_is_nil", func(t *testing.T) { + service := &databaseServiceImpl{ + config: ConnectionConfig{ + Driver: "sqlite", + DSN: "test.db", + }, + db: nil, // db is nil + logger: &MockLogger{}, + ctx: context.Background(), + } + + // This should return early and not panic + service.onTokenRefresh("new-token", "test-endpoint") + // Test passes if no panic occurs + }) +} + +// TestAWSIAMTokenProviderGetToken tests the GetToken method for cached token scenario +func TestAWSIAMTokenProviderGetToken(t *testing.T) { + t.Run("returns_cached_valid_token", func(t *testing.T) { + provider := &AWSIAMTokenProvider{ + currentToken: "cached-token", + tokenExpiry: time.Now().Add(5 * time.Minute), + } + + token, err := provider.GetToken(context.Background(), "test-endpoint:5432") + + assert.NoError(t, err) + assert.Equal(t, "cached-token", token) + }) +} + +// TestAWSIAMTokenProviderBuildDSNWithIAMToken tests the BuildDSNWithIAMToken method +func TestAWSIAMTokenProviderBuildDSNWithIAMToken(t *testing.T) { + t.Run("builds_dsn_with_cached_token", func(t *testing.T) { + provider := &AWSIAMTokenProvider{ + currentToken: "cached-token", + tokenExpiry: time.Now().Add(5 * time.Minute), + } + + dsn, err := provider.BuildDSNWithIAMToken(context.Background(), "postgres://user:password@localhost:5432/db") + + assert.NoError(t, err) + assert.Contains(t, dsn, "cached-token") + }) +} + +// TestAWSIAMTokenProviderSetTokenRefreshCallback tests the SetTokenRefreshCallback method +func TestAWSIAMTokenProviderSetTokenRefreshCallback(t *testing.T) { + provider := &AWSIAMTokenProvider{} + + callbackCalled := false + callback := func(token, endpoint string) { + callbackCalled = true + } + + provider.SetTokenRefreshCallback(callback) + + // Test that callback is stored + assert.NotNil(t, provider.tokenRefreshCallback) + + // Test that callback can be called + provider.tokenRefreshCallback("test-token", "test-endpoint") + assert.True(t, callbackCalled) +} + +// TestAWSIAMTokenProviderStopTokenRefresh tests the StopTokenRefresh method +func TestAWSIAMTokenProviderStopTokenRefresh(t *testing.T) { + t.Run("stops_token_refresh_when_not_started", func(t *testing.T) { + provider := &AWSIAMTokenProvider{ + refreshStarted: false, + } + + // This should return early and not block + provider.StopTokenRefresh() + // Test passes if method returns without blocking + }) +} + +// TestDatabaseServiceImplDB tests the DB method with connection mutex +func TestDatabaseServiceImplDB(t *testing.T) { + service := &databaseServiceImpl{ + connMutex: sync.RWMutex{}, + db: nil, + } + + // This should not panic even when db is nil + db := service.DB() + assert.Nil(t, db) +} + +// TestReplacesDSNPasswordFunction tests replaceDSNPassword function edge cases +func TestReplacesDSNPasswordFunction(t *testing.T) { + t.Run("handles_url_style_dsn_without_userinfo", func(t *testing.T) { + dsn := "postgres://localhost:5432/database" + _, err := replaceDSNPassword(dsn, "new-token") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "no user information in DSN") + }) + + t.Run("adds_password_to_key_value_dsn_when_missing", func(t *testing.T) { + dsn := "host=localhost port=5432 user=testuser dbname=testdb" + newDSN, err := replaceDSNPassword(dsn, "new-token") + + assert.NoError(t, err) + assert.Contains(t, newDSN, "password=new-token") + }) +} + +// TestLooksLikeHostnameFunction tests looksLikeHostname function edge cases +func TestLooksLikeHostnameFunction(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"empty_string", "", false}, + {"valid_hostname_with_port", "localhost:5432", true}, + {"valid_hostname_with_dot", "db.example.com", true}, + {"localhost_only", "localhost", true}, + {"invalid_with_special_chars", "host!@#$", false}, + {"hostname_with_path", "db.example.com/path", true}, + {"hostname_with_query", "db.example.com?param=value", true}, + {"starts_with_number", "127.0.0.1", true}, + {"starts_with_special_char", "!invalid", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := looksLikeHostname(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestIsHexDigitFunction tests isHexDigit function +func TestIsHexDigitFunction(t *testing.T) { + tests := []struct { + name string + input byte + expected bool + }{ + {"digit_0", '0', true}, + {"digit_9", '9', true}, + {"uppercase_A", 'A', true}, + {"uppercase_F", 'F', true}, + {"lowercase_a", 'a', true}, + {"lowercase_f", 'f', true}, + {"invalid_G", 'G', false}, + {"invalid_g", 'g', false}, + {"invalid_special", '!', false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHexDigit(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestPreprocessDSNForParsingFunction tests preprocessDSNForParsing function edge cases +func TestPreprocessDSNForParsingFunction(t *testing.T) { + t.Run("returns_non_url_dsn_unchanged", func(t *testing.T) { + dsn := "host=localhost port=5432" + result, err := preprocessDSNForParsing(dsn) + + assert.NoError(t, err) + assert.Equal(t, dsn, result) + }) + + t.Run("returns_dsn_without_credentials_unchanged", func(t *testing.T) { + dsn := "postgres://localhost:5432/database" + result, err := preprocessDSNForParsing(dsn) + + assert.NoError(t, err) + assert.Equal(t, dsn, result) + }) + + t.Run("returns_dsn_without_password_unchanged", func(t *testing.T) { + dsn := "postgres://username@localhost:5432/database" + result, err := preprocessDSNForParsing(dsn) + + assert.NoError(t, err) + assert.Equal(t, dsn, result) + }) + + t.Run("returns_already_encoded_dsn_unchanged", func(t *testing.T) { + dsn := "postgres://username:password%21@localhost:5432/database" + result, err := preprocessDSNForParsing(dsn) + + assert.NoError(t, err) + assert.Equal(t, dsn, result) + }) +} \ No newline at end of file diff --git a/modules/database/token_refresh_integration_test.go b/modules/database/token_refresh_integration_test.go new file mode 100644 index 00000000..2037b33d --- /dev/null +++ b/modules/database/token_refresh_integration_test.go @@ -0,0 +1,77 @@ +package database + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestTokenRefreshCallbackIntegration tests the integration of token refresh callback functionality +func TestTokenRefreshCallbackIntegration(t *testing.T) { + // This test verifies that the TokenRefreshCallback interface and implementation work + + // Create a simple mock callback + var callbackExecuted bool + var receivedToken, receivedEndpoint string + + callback := func(token, endpoint string) { + callbackExecuted = true + receivedToken = token + receivedEndpoint = endpoint + } + + // Test that the callback can be set and called + provider := &AWSIAMTokenProvider{} + provider.SetTokenRefreshCallback(callback) + + // Simulate calling the callback (this would normally happen during token refresh) + if provider.tokenRefreshCallback != nil { + provider.tokenRefreshCallback("test-token", "test-endpoint") + } + + // Verify the callback was executed with correct parameters + assert.True(t, callbackExecuted, "Token refresh callback should have been executed") + assert.Equal(t, "test-token", receivedToken, "Callback should receive the correct token") + assert.Equal(t, "test-endpoint", receivedEndpoint, "Callback should receive the correct endpoint") +} + +// TestOnTokenRefreshMethodExists tests that the onTokenRefresh method exists and can be called +func TestOnTokenRefreshMethodExists(t *testing.T) { + // Create a database service implementation + service := &databaseServiceImpl{ + config: ConnectionConfig{ + Driver: "sqlite", + DSN: ":memory:", + }, + logger: &MockLogger{}, + ctx: context.Background(), + db: nil, // Start with nil db to test early return + } + + // This should not panic and should return early since db is nil + service.onTokenRefresh("test-token", "test-endpoint") + + // Test passes if no panic occurs + assert.True(t, true, "onTokenRefresh method executed without panic") +} + +// TestIAMTokenProviderInterface tests that our provider implements the interface +func TestIAMTokenProviderInterface(t *testing.T) { + // This test ensures our AWSIAMTokenProvider properly implements IAMTokenProvider interface + var provider IAMTokenProvider = &AWSIAMTokenProvider{} + + // Test that all interface methods exist + assert.NotNil(t, provider, "Provider should implement IAMTokenProvider interface") + + // Test SetTokenRefreshCallback method exists + callback := func(token, endpoint string) { + // Callback implementation for testing + } + + provider.SetTokenRefreshCallback(callback) + + // Verify callback was set (we can't easily test this without accessing private fields, + // but the fact that the method call succeeds proves the interface is implemented) + assert.True(t, true, "SetTokenRefreshCallback method exists and can be called") +} \ No newline at end of file From a3e1dc56e213319a70ca2029a71f40e1c7852f70 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Sep 2025 08:17:29 +0000 Subject: [PATCH 4/4] Replace fmt.Printf with logger service in AWS IAM token refresh panic recovery Co-authored-by: intel352 <77607+intel352@users.noreply.github.com> --- modules/database/aws_iam_auth.go | 9 ++++++--- modules/database/aws_iam_auth_test.go | 9 ++++++--- modules/database/service.go | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/modules/database/aws_iam_auth.go b/modules/database/aws_iam_auth.go index 4cbd03b1..90516055 100644 --- a/modules/database/aws_iam_auth.go +++ b/modules/database/aws_iam_auth.go @@ -10,6 +10,7 @@ import ( "errors" + "github.com/GoCodeAlone/modular" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" @@ -47,10 +48,11 @@ type AWSIAMTokenProvider struct { refreshDone chan struct{} refreshStarted bool tokenRefreshCallback TokenRefreshCallback + logger modular.Logger // Logger service for error reporting } // NewAWSIAMTokenProvider creates a new AWS IAM token provider -func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig) (*AWSIAMTokenProvider, error) { +func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig, logger modular.Logger) (*AWSIAMTokenProvider, error) { if authConfig == nil || !authConfig.Enabled { return nil, ErrIAMAuthNotEnabled } @@ -80,6 +82,7 @@ func NewAWSIAMTokenProvider(authConfig *AWSIAMAuthConfig) (*AWSIAMTokenProvider, awsConfig: awsConfig, stopChan: make(chan struct{}), refreshDone: make(chan struct{}), + logger: logger, } return provider, nil @@ -126,8 +129,8 @@ func (p *AWSIAMTokenProvider) refreshToken(ctx context.Context, endpoint string) defer func() { if r := recover(); r != nil { // Log the panic but don't fail the token refresh process - // The actual logging will be handled by the callback implementation if available - fmt.Printf("Database token refresh callback panic recovered: %v\n", r) + // Use the logger service for proper error reporting + p.logger.Error("Database token refresh callback panic recovered", "panic", r) } }() p.tokenRefreshCallback(token, endpoint) diff --git a/modules/database/aws_iam_auth_test.go b/modules/database/aws_iam_auth_test.go index 7491bc50..24d72308 100644 --- a/modules/database/aws_iam_auth_test.go +++ b/modules/database/aws_iam_auth_test.go @@ -79,7 +79,8 @@ func TestAWSIAMAuthConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - provider, err := NewAWSIAMTokenProvider(tt.config) + mockLogger := &MockLogger{} + provider, err := NewAWSIAMTokenProvider(tt.config, mockLogger) if tt.wantErr { require.Error(t, err) require.Nil(t, provider) @@ -294,7 +295,8 @@ func TestAWSIAMTokenProvider_NoDeadlockOnClose(t *testing.T) { TokenRefreshInterval: 300, } - provider, err := NewAWSIAMTokenProvider(config) + mockLogger := &MockLogger{} + provider, err := NewAWSIAMTokenProvider(config, mockLogger) if err != nil { if strings.Contains(err.Error(), "failed to load AWS config") { t.Skip("AWS credentials not available, skipping test") @@ -326,7 +328,8 @@ func TestAWSIAMTokenProvider_StartStopRefresh(t *testing.T) { TokenRefreshInterval: 1, // Short interval for testing } - provider, err := NewAWSIAMTokenProvider(config) + mockLogger := &MockLogger{} + provider, err := NewAWSIAMTokenProvider(config, mockLogger) if err != nil { if strings.Contains(err.Error(), "failed to load AWS config") { t.Skip("AWS credentials not available, skipping test") diff --git a/modules/database/service.go b/modules/database/service.go index abe763a2..0dfe474e 100644 --- a/modules/database/service.go +++ b/modules/database/service.go @@ -134,7 +134,7 @@ func NewDatabaseService(config ConnectionConfig, logger modular.Logger) (Databas // Initialize AWS IAM token provider if enabled if config.AWSIAMAuth != nil && config.AWSIAMAuth.Enabled { - tokenProvider, err := NewAWSIAMTokenProvider(config.AWSIAMAuth) + tokenProvider, err := NewAWSIAMTokenProvider(config.AWSIAMAuth, logger) if err != nil { cancel() return nil, fmt.Errorf("failed to create AWS IAM token provider: %w", err)