diff --git a/.test/test/base_connection_plugin_test.go b/.test/test/base_connection_plugin_test.go index 71ac9c2a..2042fb1e 100644 --- a/.test/test/base_connection_plugin_test.go +++ b/.test/test/base_connection_plugin_test.go @@ -141,7 +141,7 @@ func TestBaseConnectionPlugin_InitHostProvider_CallsInitHostProviderFunc(t *test return nil } - err := plugin.InitHostProvider("url", map[string]string{}, nil, initFunc) + err := plugin.InitHostProvider(map[string]string{}, nil, initFunc) assert.True(t, called) assert.NoError(t, err) } @@ -154,6 +154,6 @@ func TestBaseConnectionPlugin_InitHostProvider_PropagatesError(t *testing.T) { return expectedErr } - err := plugin.InitHostProvider("url", map[string]string{}, nil, initFunc) + err := plugin.InitHostProvider(map[string]string{}, nil, initFunc) assert.ErrorIs(t, err, expectedErr) } diff --git a/.test/test/benchmark_plugin.go b/.test/test/benchmark_plugin.go index 0986fbeb..1dcffd53 100644 --- a/.test/test/benchmark_plugin.go +++ b/.test/test/benchmark_plugin.go @@ -93,7 +93,7 @@ func (b *BenchmarkPlugin) NotifyHostListChanged(changes map[string]map[driver_in } func (b *BenchmarkPlugin) InitHostProvider( - initialUrl string, props map[string]string, + props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { b.resources = append(b.resources, "initHostProvider") diff --git a/.test/test/bg_helpers_test.go b/.test/test/bg_helpers_test.go new file mode 100644 index 00000000..fa4c4dc8 --- /dev/null +++ b/.test/test/bg_helpers_test.go @@ -0,0 +1,757 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "testing" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/stretchr/testify/assert" +) + +func TestBlueGreenPhase(t *testing.T) { + tests := []struct { + name string + phase driver_infrastructure.BlueGreenPhase + expectedName string + expectedPhase int + expectedIsActive bool + }{ + { + name: "NotCreated", + phase: driver_infrastructure.NOT_CREATED, + expectedName: "NOT_CREATED", + expectedPhase: 0, + expectedIsActive: false, + }, + { + name: "Created", + phase: driver_infrastructure.CREATED, + expectedName: "CREATED", + expectedPhase: 1, + expectedIsActive: false, + }, + { + name: "Preparation", + phase: driver_infrastructure.PREPARATION, + expectedName: "PREPARATION", + expectedPhase: 2, + expectedIsActive: true, + }, + { + name: "InProgress", + phase: driver_infrastructure.IN_PROGRESS, + expectedName: "IN_PROGRESS", + expectedPhase: 3, + expectedIsActive: true, + }, + { + name: "Post", + phase: driver_infrastructure.POST, + expectedName: "POST", + expectedPhase: 4, + expectedIsActive: true, + }, + { + name: "Completed", + phase: driver_infrastructure.COMPLETED, + expectedName: "COMPLETED", + expectedPhase: 5, + expectedIsActive: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name := tt.phase.GetName() + assert.Equal(t, tt.expectedName, name) + phaseInt := tt.phase.GetPhase() + assert.Equal(t, tt.expectedPhase, phaseInt) + isActiveSwitchoverOrCompleted := tt.phase.IsActiveSwitchoverOrCompleted() + assert.Equal(t, tt.expectedIsActive, isActiveSwitchoverOrCompleted) + }) + } +} + +func TestBlueGreenPhase_IsZero(t *testing.T) { + tests := []struct { + name string + phase driver_infrastructure.BlueGreenPhase + expected bool + }{ + { + name: "NotCreated", + phase: driver_infrastructure.NOT_CREATED, + expected: false, + }, + { + name: "Created", + phase: driver_infrastructure.CREATED, + expected: false, + }, + { + name: "EmptyPhase", + phase: driver_infrastructure.BlueGreenPhase{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.phase.IsZero() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenPhase_Equals(t *testing.T) { + tests := []struct { + name string + phase1 driver_infrastructure.BlueGreenPhase + phase2 driver_infrastructure.BlueGreenPhase + expected bool + }{ + { + name: "SamePhases", + phase1: driver_infrastructure.CREATED, + phase2: driver_infrastructure.CREATED, + expected: true, + }, + { + name: "DifferentPhases", + phase1: driver_infrastructure.CREATED, + phase2: driver_infrastructure.PREPARATION, + expected: false, + }, + { + name: "ZeroPhases", + phase1: driver_infrastructure.BlueGreenPhase{}, + phase2: driver_infrastructure.BlueGreenPhase{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.phase1.Equals(tt.phase2) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenParsePhase(t *testing.T) { + tests := []struct { + name string + statusKey string + expected driver_infrastructure.BlueGreenPhase + }{ + { + name: "EmptyString", + statusKey: "", + expected: driver_infrastructure.NOT_CREATED, + }, + { + name: "Available", + statusKey: "AVAILABLE", + expected: driver_infrastructure.CREATED, + }, + { + name: "SwitchoverInitiated", + statusKey: "SWITCHOVER_INITIATED", + expected: driver_infrastructure.PREPARATION, + }, + { + name: "SwitchoverInProgress", + statusKey: "SWITCHOVER_IN_PROGRESS", + expected: driver_infrastructure.IN_PROGRESS, + }, + { + name: "SwitchoverInPostProcessing", + statusKey: "SWITCHOVER_IN_POST_PROCESSING", + expected: driver_infrastructure.POST, + }, + { + name: "SwitchoverCompleted", + statusKey: "SWITCHOVER_COMPLETED", + expected: driver_infrastructure.COMPLETED, + }, + { + name: "LowercaseStatus", + statusKey: "available", + expected: driver_infrastructure.CREATED, + }, + { + name: "MixedCaseStatus", + statusKey: "Switchover_Initiated", + expected: driver_infrastructure.PREPARATION, + }, + { + name: "UnknownStatus", + statusKey: "UNKNOWN_STATUS", + expected: driver_infrastructure.BlueGreenPhase{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := driver_infrastructure.ParsePhase(tt.statusKey) + if tt.statusKey == "UNKNOWN_STATUS" { + assert.True(t, result.IsZero()) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestBlueGreenRole_GetName(t *testing.T) { + tests := []struct { + name string + role driver_infrastructure.BlueGreenRole + expected string + }{ + { + name: "SourceRole", + role: driver_infrastructure.SOURCE, + expected: "SOURCE", + }, + { + name: "TargetRole", + role: driver_infrastructure.TARGET, + expected: "TARGET", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.role.GetName() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenRole_GetValue(t *testing.T) { + tests := []struct { + name string + role driver_infrastructure.BlueGreenRole + expected int + }{ + { + name: "SourceRole", + role: driver_infrastructure.SOURCE, + expected: 0, + }, + { + name: "TargetRole", + role: driver_infrastructure.TARGET, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.role.GetValue() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenRole_IsZero(t *testing.T) { + tests := []struct { + name string + role driver_infrastructure.BlueGreenRole + expected bool + }{ + { + name: "SourceRole", + role: driver_infrastructure.SOURCE, + expected: false, + }, + { + name: "TargetRole", + role: driver_infrastructure.TARGET, + expected: false, + }, + { + name: "EmptyRole", + role: driver_infrastructure.BlueGreenRole{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.role.IsZero() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenRole_String(t *testing.T) { + tests := []struct { + name string + role driver_infrastructure.BlueGreenRole + expected string + }{ + { + name: "Source", + role: driver_infrastructure.SOURCE, + expected: "BlueGreenRole [name: SOURCE, value: 0]", + }, + { + name: "Target", + role: driver_infrastructure.TARGET, + expected: "BlueGreenRole [name: TARGET, value: 1]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.role.String() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenParseRole(t *testing.T) { + tests := []struct { + name string + roleKey string + expected driver_infrastructure.BlueGreenRole + }{ + { + name: "Source", + roleKey: "BLUE_GREEN_DEPLOYMENT_SOURCE", + expected: driver_infrastructure.SOURCE, + }, + { + name: "Target", + roleKey: "BLUE_GREEN_DEPLOYMENT_TARGET", + expected: driver_infrastructure.TARGET, + }, + { + name: "LowercaseRole", + roleKey: "blue_green_deployment_source", + expected: driver_infrastructure.SOURCE, + }, + { + name: "MixedCaseRole", + roleKey: "Blue_Green_Deployment_Target", + expected: driver_infrastructure.TARGET, + }, + { + name: "UnknownRole", + roleKey: "UNKNOWN_ROLE", + expected: driver_infrastructure.BlueGreenRole{}, + }, + { + name: "EmptyRole", + roleKey: "", + expected: driver_infrastructure.BlueGreenRole{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := driver_infrastructure.ParseRole(tt.roleKey) + if tt.roleKey == "UNKNOWN_ROLE" || tt.roleKey == "" { + assert.True(t, result.IsZero()) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} +func TestNewBgStatus(t *testing.T) { + id := "test-bg-id" + phase := driver_infrastructure.CREATED + var connectRoutings []driver_infrastructure.ConnectRouting + var executeRoutings []driver_infrastructure.ExecuteRouting + roleByHost := utils.NewRWMap[driver_infrastructure.BlueGreenRole]() + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + + status := driver_infrastructure.NewBgStatus(id, phase, connectRoutings, executeRoutings, roleByHost, correspondingHosts) + assert.Equal(t, phase, status.GetCurrentPhase()) + assert.Equal(t, connectRoutings, status.GetConnectRoutings()) + assert.Equal(t, executeRoutings, status.GetExecuteRoutings()) + assert.NotNil(t, status.GetCorrespondingHosts()) +} + +func TestBlueGreenStatus_GetRole(t *testing.T) { + roleByHost := utils.NewRWMap[driver_infrastructure.BlueGreenRole]() + + host := &host_info_util.HostInfo{Host: "test.example.com", Port: 5432} + roleByHost.Put("test.example.com", driver_infrastructure.SOURCE) + + status := driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.CREATED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + roleByHost, + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ) + + tests := []struct { + name string + hostInfo *host_info_util.HostInfo + expectedRole driver_infrastructure.BlueGreenRole + expectedOk bool + }{ + { + name: "ExistingHost", + hostInfo: host, + expectedRole: driver_infrastructure.SOURCE, + expectedOk: true, + }, + { + name: "Non-existingHost", + hostInfo: &host_info_util.HostInfo{Host: "nonexistent.example.com", Port: 5432}, + expectedRole: driver_infrastructure.BlueGreenRole{}, + expectedOk: false, + }, + { + name: "NilHost", + hostInfo: nil, + expectedRole: driver_infrastructure.BlueGreenRole{}, + expectedOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + role, ok := status.GetRole(tt.hostInfo) + assert.Equal(t, tt.expectedOk, ok) + if tt.expectedOk { + assert.Equal(t, tt.expectedRole, role) + } + }) + } +} + +func TestBlueGreenStatus_IsZero(t *testing.T) { + tests := []struct { + name string + status driver_infrastructure.BlueGreenStatus + expected bool + }{ + { + name: "Zero", + status: driver_infrastructure.BlueGreenStatus{}, + expected: true, + }, + { + name: "NonZeroId", + status: driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.BlueGreenPhase{}, + nil, + nil, + nil, + nil, + ), + expected: false, + }, + { + name: "NonZeroPhase", + status: driver_infrastructure.NewBgStatus( + "", + driver_infrastructure.CREATED, + nil, + nil, + nil, + nil, + ), + expected: false, + }, + { + name: "NonZeroRouting", + status: driver_infrastructure.NewBgStatus( + "", + driver_infrastructure.BlueGreenPhase{}, + []driver_infrastructure.ConnectRouting{}, + nil, + nil, + nil, + ), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.status.IsZero() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenStatus_MatchIdPhaseAndLen(t *testing.T) { + baseStatus := driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.CREATED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ) + + tests := []struct { + name string + other driver_infrastructure.BlueGreenStatus + expected bool + }{ + { + name: "IdenticalStatus", + other: driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.CREATED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ), + expected: true, + }, + { + name: "DifferentID", + other: driver_infrastructure.NewBgStatus( + "different-id", + driver_infrastructure.CREATED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ), + expected: false, + }, + { + name: "DifferentPhase", + other: driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.PREPARATION, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := baseStatus.MatchIdPhaseAndLen(tt.other) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenStatus_String(t *testing.T) { + status := driver_infrastructure.NewBgStatus( + "test-bg-id", + driver_infrastructure.CREATED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ) + + result := status.String() + + assert.Contains(t, result, "BlueGreenStatus") + assert.Contains(t, result, "test-bg-id") + assert.Contains(t, result, "CREATED") + assert.Contains(t, result, "connect routing") + assert.Contains(t, result, "execute routing") + assert.Contains(t, result, "roleByHost") +} + +func TestBlueGreenResult_String(t *testing.T) { + result := &driver_infrastructure.BlueGreenResult{ + Version: "1.0", + Endpoint: "test.example.com", + Port: 5432, + Role: "SOURCE", + Status: "AVAILABLE", + } + + stringResult := result.String() + + assert.Contains(t, stringResult, "BlueGreenResult") + assert.Contains(t, stringResult, "1.0") + assert.Contains(t, stringResult, "test.example.com") + assert.Contains(t, stringResult, "5432") + assert.Contains(t, stringResult, "SOURCE") + assert.Contains(t, stringResult, "AVAILABLE") +} + +func TestBlueGreenRoutingResultHolder_GetResult(t *testing.T) { + tests := []struct { + name string + holder driver_infrastructure.RoutingResultHolder + expectedValue1 any + expectedValue2 any + expectedOk bool + expectedErr error + }{ + { + name: "ReturnsAllWrappedValues", + holder: driver_infrastructure.RoutingResultHolder{ + WrappedReturnValue: "test-value", + WrappedReturnValue2: 42, + WrappedOk: true, + WrappedErr: nil, + }, + expectedValue1: "test-value", + expectedValue2: 42, + expectedOk: true, + expectedErr: nil, + }, + { + name: "ReturnsError", + holder: driver_infrastructure.RoutingResultHolder{ + WrappedReturnValue: nil, + WrappedReturnValue2: nil, + WrappedOk: false, + WrappedErr: assert.AnError, + }, + expectedValue1: nil, + expectedValue2: nil, + expectedOk: false, + expectedErr: assert.AnError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value1, value2, ok, err := tt.holder.GetResult() + assert.Equal(t, tt.expectedValue1, value1) + assert.Equal(t, tt.expectedValue2, value2) + assert.Equal(t, tt.expectedOk, ok) + assert.Equal(t, tt.expectedErr, err) + }) + } +} + +func TestBlueGreenRoutingResultHolder_IsPresent(t *testing.T) { + tests := []struct { + name string + holder driver_infrastructure.RoutingResultHolder + expected bool + }{ + { + name: "Empty", + holder: driver_infrastructure.EMPTY_ROUTING_RESULT_HOLDER, + expected: false, + }, + { + name: "NonEmpty", + holder: driver_infrastructure.RoutingResultHolder{ + WrappedReturnValue: "test-value", + }, + expected: true, + }, + { + name: "Error", + holder: driver_infrastructure.RoutingResultHolder{ + WrappedErr: assert.AnError, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.holder.IsPresent() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBlueGreenConstants(t *testing.T) { + assert.Equal(t, driver_infrastructure.BlueGreenIntervalRate(0), driver_infrastructure.BASELINE) + assert.Equal(t, driver_infrastructure.BlueGreenIntervalRate(1), driver_infrastructure.INCREASED) + assert.Equal(t, driver_infrastructure.BlueGreenIntervalRate(2), driver_infrastructure.HIGH) + + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_SOURCE", driver_infrastructure.BLUE_GREEN_SOURCE) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_TARGET", driver_infrastructure.BLUE_GREEN_TARGET) + + assert.Equal(t, "AVAILABLE", driver_infrastructure.AVAILABLE) + assert.Equal(t, "SWITCHOVER_INITIATED", driver_infrastructure.SWITCHOVER_INITIATED) + assert.Equal(t, "SWITCHOVER_IN_PROGRESS", driver_infrastructure.SWITCHOVER_IN_PROGRESS) + assert.Equal(t, "SWITCHOVER_IN_POST_PROCESSING", driver_infrastructure.SWITCHOVER_IN_POST_PROCESSING) + assert.Equal(t, "SWITCHOVER_COMPLETED", driver_infrastructure.SWITCHOVER_COMPLETED) +} + +func TestBlueGreenInterimStatus_IsZero(t *testing.T) { + var nilStatus *bg.BlueGreenInterimStatus + assert.True(t, nilStatus.IsZero(), "Nil status should be zero") + + zeroStatus := &bg.BlueGreenInterimStatus{} + assert.True(t, zeroStatus.IsZero(), "Empty status should be zero") + + nonZeroStatus := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + nil, nil, false, false, false) + + assert.False(t, nonZeroStatus.IsZero(), "Status with version should not be zero") +} + +func TestBlueGreenInterimStatus_String(t *testing.T) { + status := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.CREATED, + nil, nil, false, false, false) + + result := status.String() + assert.Contains(t, result, "BlueGreenInterimStatus", "String should contain type name") + assert.Contains(t, result, "CREATED", "String should contain phase") + assert.Contains(t, result, "1.0", "String should contain version") + assert.Contains(t, result, "1234", "String should contain port") +} + +func TestBlueGreenInterimStatus_GetCustomHashCode(t *testing.T) { + status1 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.CREATED, + nil, nil, false, false, false) + + status2 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.CREATED, + nil, nil, false, false, false) + + status3 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.PREPARATION, + nil, nil, false, false, false) + + // Same content should produce same hash + hash1 := status1.GetCustomHashCode() + hash2 := status2.GetCustomHashCode() + hash3 := status3.GetCustomHashCode() + assert.Equal(t, hash1, hash2, "Same content should produce same hash") + assert.NotEqual(t, hash1, hash3, "Different content should produce different hash") + assert.NotEqual(t, hash2, hash3, "Different content should produce different hash") +} + +func TestStatusInfo_IsZero(t *testing.T) { + var nilStatus *bg.StatusInfo + assert.True(t, nilStatus.IsZero(), "Nil StatusInfo should be zero") + + zeroStatus := &bg.StatusInfo{} + assert.True(t, zeroStatus.IsZero(), "Empty StatusInfo should be zero") + + nonZeroStatus := bg.NewTestStatusInfo() + + assert.False(t, nonZeroStatus.IsZero(), "StatusInfo with version should not be zero") +} diff --git a/.test/test/bg_plugin_test.go b/.test/test/bg_plugin_test.go new file mode 100644 index 00000000..8922741f --- /dev/null +++ b/.test/test/bg_plugin_test.go @@ -0,0 +1,325 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "context" + "database/sql/driver" + "errors" + "testing" + + mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" + mock_telemetry "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/util/telemetry" + mock_database_sql_driver "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/database_sql_driver" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBlueGreenPluginFactory_GetInstance(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + factory := bg.NewBlueGreenPluginFactory() + require.NotNil(t, factory) + assert.IsType(t, &bg.BlueGreenPluginFactory{}, factory) + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := factory.GetInstance(mockPluginService, map[string]string{ + property_util.BGD_ID.Name: "", + }) + assert.Nil(t, plugin) + require.NotNil(t, err) + assert.Equal(t, error_util.GetMessage("BlueGreenDeployment.bgIdRequired"), err.Error()) + + plugin, err = factory.GetInstance(mockPluginService, map[string]string{ + property_util.BGD_ID.Name: "test-bg-id", + }) + assert.NoError(t, err) + assert.NotNil(t, plugin) + assert.IsType(t, &bg.BlueGreenPlugin{}, plugin) +} + +func TestBlueGreenPlugin_GetSubscribedMethods(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + props := map[string]string{ + property_util.BGD_ID.Name: "test-bg-id", + } + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + methods := plugin.GetSubscribedMethods() + + assert.NotEmpty(t, methods) + assert.Contains(t, methods, "Conn.Connect") + assert.Contains(t, methods, "Conn.QueryContext") + assert.Contains(t, methods, "Stmt.ExecContext") +} + +func TestBlueGreenPlugin_Connect(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + connectFunc := func(props map[string]string) (driver.Conn, error) { + return mockConn, nil + } + emptyStatus := driver_infrastructure.BlueGreenStatus{} + hostInfo := &host_info_util.HostInfo{Host: "test-host"} + roleByHost := utils.NewRWMap[driver_infrastructure.BlueGreenRole]() + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + connectRouting := bg.NewSubstituteConnectRouting(hostInfo.GetHostAndPort(), driver_infrastructure.SOURCE, hostInfo, nil, nil) + bgStatus := driver_infrastructure.NewBgStatus("test-bg-id", driver_infrastructure.CREATED, []driver_infrastructure.ConnectRouting{connectRouting}, + nil, roleByHost, correspondingHosts) + props := map[string]string{ + property_util.BGD_ID.Name: "test-bg-id", + } + defer (&bg.BlueGreenPluginFactory{}).ClearCaches() + + t.Run("NoBlueGreenStatus", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(emptyStatus, false) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + conn, err := plugin.Connect(hostInfo, props, true, connectFunc) + + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("InitialConnectionWithIAM", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(true) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + conn, err := plugin.Connect(hostInfo, props, true, connectFunc) + + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("NoMatchingHostRole", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(false) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + conn, err := plugin.Connect(hostInfo, props, true, connectFunc) + + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("NoMatchingRoutes", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + roleByHost.Put(hostInfo.GetHost(), driver_infrastructure.TARGET) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(false) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + conn, err := plugin.Connect(hostInfo, props, true, connectFunc) + + assert.NoError(t, err) + assert.Equal(t, mockConn, conn) + }) + + t.Run("RoutingConnects", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + roleByHost.Put(hostInfo.GetHost(), driver_infrastructure.SOURCE) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(false).AnyTimes() + mockPluginService.EXPECT().Connect(hostInfo, props, gomock.Any()).Return(mockConn, nil) + + conn, err := plugin.Connect(hostInfo, props, true, connectFunc) + + assert.NoError(t, err) + assert.NotNil(t, conn) + }) +} + +func TestBlueGreenPlugin_Execute(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + emptyStatus := driver_infrastructure.BlueGreenStatus{} + props := map[string]string{ + property_util.BGD_ID.Name: "test-bg-id", + } + executeFunc := func() (any, any, bool, error) { + return "result", nil, true, nil + } + roleByHost := utils.NewRWMap[driver_infrastructure.BlueGreenRole]() + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + hostInfo := &host_info_util.HostInfo{Host: "test-host"} + executeRouting := bg.NewSuspendExecuteRouting(hostInfo.GetHostAndPort(), driver_infrastructure.SOURCE, "test-bg-id") + bgStatus := driver_infrastructure.NewBgStatus("test-bg-id", driver_infrastructure.COMPLETED, []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, roleByHost, correspondingHosts) + defer (&bg.BlueGreenPluginFactory{}).ClearCaches() + + t.Run("ClosingMethod", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(emptyStatus, false) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + result, result2, ok, err := plugin.Execute(mockConn, "Close", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) + + t.Run("NoBlueGreenStatus", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(emptyStatus, false) + + result, result2, ok, err := plugin.Execute(mockConn, "Query", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) + + t.Run("ErrorGettingCurrentHost", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().GetCurrentHostInfo().Return((*host_info_util.HostInfo)(nil), errors.New("host error")) + + result, result2, ok, err := plugin.Execute(mockConn, "Query", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) + + t.Run("NoMatchingHostRole", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().GetCurrentHostInfo().Return(hostInfo, nil) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + result, result2, ok, err := plugin.Execute(mockConn, "Query", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) + + t.Run("NoMatchingRoutes", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + roleByHost.Put(hostInfo.GetHost(), driver_infrastructure.TARGET) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true) + mockPluginService.EXPECT().GetCurrentHostInfo().Return(hostInfo, nil) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + result, result2, ok, err := plugin.Execute(mockConn, "Query", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) + + t.Run("MatchingRoute", func(t *testing.T) { + mockTelemetry := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCtx := mock_telemetry.NewMockTelemetryContext(ctrl) + bgStatusWithRouting := driver_infrastructure.NewBgStatus("test-bg-id", driver_infrastructure.COMPLETED, []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{executeRouting}, roleByHost, correspondingHosts) + + ctxBefore := context.Background() + mockTelemetry.EXPECT(). + OpenTelemetryContext(gomock.Any(), telemetry.NESTED, ctxBefore). + Return(mockTelemetryCtx, ctxBefore).AnyTimes() + mockTelemetryCtx.EXPECT().CloseContext().AnyTimes() + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + roleByHost.Put(hostInfo.GetHost(), driver_infrastructure.SOURCE) + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + plugin, err := bg.NewBlueGreenPlugin(mockPluginService, props) + assert.NoError(t, err) + + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatusWithRouting, true).Times(2) + mockPluginService.EXPECT().GetStatus("test-bg-id").Return(bgStatus, true).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(hostInfo, nil) + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + result, result2, ok, err := plugin.Execute(mockConn, "Query", executeFunc) + + assert.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, "result", result) + assert.Nil(t, result2) + }) +} diff --git a/.test/test/bg_routing_test.go b/.test/test/bg_routing_test.go new file mode 100644 index 00000000..c64d9703 --- /dev/null +++ b/.test/test/bg_routing_test.go @@ -0,0 +1,483 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" + mock_telemetry "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/util/telemetry" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBaseRouting(t *testing.T) { + hostAndPort := "myapp-prod.cluster-abc123.us-east-1.rds.amazonaws.com:5432" + role := driver_infrastructure.SOURCE + + routing := bg.NewBaseRouting(hostAndPort, role) + + assert.NotNil(t, routing, "Should create BaseRouting instance") +} + +func TestBaseRoutingIsMatch(t *testing.T) { + hostAndPort := "test-host:5432" + role := driver_infrastructure.SOURCE + routing := bg.NewBaseRouting(hostAndPort, role) + + matchingHost, _ := host_info_util.NewHostInfoBuilder(). + SetHost("test-host"). + SetPort(5432). + Build() + + assert.True(t, routing.IsMatch(matchingHost, driver_infrastructure.SOURCE), + "Should match exact host and role") + + assert.False(t, routing.IsMatch(matchingHost, driver_infrastructure.TARGET), + "Should not match different role") + + differentHost, _ := host_info_util.NewHostInfoBuilder(). + SetHost("different-host"). + SetPort(5432). + Build() + + assert.False(t, routing.IsMatch(differentHost, driver_infrastructure.SOURCE), + "Should not match different host") + + emptyRouting := bg.NewBaseRouting("", driver_infrastructure.BlueGreenRole{}) + assert.True(t, emptyRouting.IsMatch(matchingHost, driver_infrastructure.SOURCE), + "Empty routing should match any host and role") +} + +func TestBaseRoutingString(t *testing.T) { + hostAndPort := "test-host:5432" + role := driver_infrastructure.SOURCE + routing := bg.NewBaseRouting(hostAndPort, role) + + result := routing.String() + assert.Contains(t, result, "Routing", "String should contain 'Routing'") + assert.Contains(t, result, "test-host:5432", "String should contain host and port") + assert.Contains(t, result, role.String(), "String should contain role") +} + +func TestRejectConnectRouting(t *testing.T) { + hostAndPort := "test-host:5432" + role := driver_infrastructure.SOURCE + + routing := bg.NewRejectConnectRouting(hostAndPort, role) + assert.NotNil(t, routing, "Should create RejectConnectRouting instance") + + hostInfo, _ := host_info_util.NewHostInfoBuilder(). + SetHost("test-host"). + SetPort(5432). + Build() + + props := make(map[string]string) + + conn, err := routing.Apply(nil, hostInfo, props, true, nil) + assert.Nil(t, conn, "Should return nil connection") + assert.NotNil(t, err, "Should return error") + assert.Contains(t, err.Error(), "in progress", "Error should mention in progress") +} + +func TestSubstituteConnectRouting(t *testing.T) { + hostAndPort := "original-host:5432" + role := driver_infrastructure.SOURCE + + substituteHost, _ := host_info_util.NewHostInfoBuilder(). + SetHost("substitute-host"). + SetPort(5432). + Build() + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, substituteHost, []*host_info_util.HostInfo{}, nil) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + + result := routing.String() + assert.Contains(t, result, "SubstituteConnectRouting", "String should contain routing type") + assert.Contains(t, result, "substitute-host:5432", "String should contain substitute host") +} + +func TestSubstituteConnectRoutingApply(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + hostAndPort := "original-host:5432" + role := driver_infrastructure.SOURCE + substituteHost, _ := host_info_util.NewHostInfoBuilder(). + SetHost("substitute-host"). + SetPort(5432). + Build() + mockDriverConn := MockDriverConn{} + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + + t.Run("SubstituteHostIsIp", func(t *testing.T) { + substituteHostIp, _ := host_info_util.NewHostInfoBuilder(). + SetHost("12.34.56.78"). + SetPort(5432). + Build() + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().Connect(substituteHostIp, nil, nil).Return(mockDriverConn, errors.New("ip connect")) + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, substituteHostIp, []*host_info_util.HostInfo{}, nil) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + conn, err := routing.Apply(nil, hostInfo, nil, true, mockPluginService) + require.NotNil(t, err, "Should return ip connect error") + assert.Equal(t, "ip connect", err.Error(), "Should return ip connect error") + assert.NotNil(t, conn, "Should create a connection") + }) + + t.Run("IamNotInUse", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().Connect(substituteHost, nil, nil).Return(mockDriverConn, errors.New("no iam connect")) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(false) + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, substituteHost, []*host_info_util.HostInfo{}, nil) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + conn, err := routing.Apply(nil, hostInfo, nil, true, mockPluginService) + require.NotNil(t, err, "Should return direct connect error") + assert.Equal(t, "no iam connect", err.Error(), "Should return direct connect error") + assert.NotNil(t, conn, "Should create a connection") + }) + + t.Run("NilIamHosts", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(true) + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, substituteHost, []*host_info_util.HostInfo{}, nil) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + conn, err := routing.Apply(nil, hostInfo, nil, true, mockPluginService) + require.NotNil(t, err, "Should return requireIamHost error") + assert.Equal(t, error_util.GetMessage("BlueGreenDeployment.requireIamHost"), err.Error()) + assert.Nil(t, conn, "Should fail to create a connection") + }) + + t.Run("UnsuccessfulConnect", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().Connect(gomock.Any(), gomock.Any(), nil).Return(nil, errors.New("unsuccessful connect")) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(true) + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, nil, []*host_info_util.HostInfo{hostInfo}, nil) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + conn, err := routing.Apply(nil, hostInfo, nil, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.Equal(t, error_util.GetMessage("BlueGreenDeployment.inProgressCantOpenConnection", ""), err.Error()) + assert.Nil(t, conn, "Should fail to create a connection") + }) + + t.Run("SuccessfulConnect", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().Connect(gomock.Any(), gomock.Any(), nil).Return(mockDriverConn, nil) + mockPluginService.EXPECT().IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE).Return(true) + var iamHost string + iamSuccessfulConnectNotify := func(s string) { + iamHost = s + } + + routing := bg.NewSubstituteConnectRouting(hostAndPort, role, nil, []*host_info_util.HostInfo{nil, hostInfo}, iamSuccessfulConnectNotify) + assert.NotNil(t, routing, "Should create SubstituteConnectRouting instance") + conn, err := routing.Apply(nil, hostInfo, nil, true, mockPluginService) + require.Nil(t, err) + assert.NotNil(t, conn, "Should create a connection") + assert.Equal(t, hostInfo.GetHost(), iamHost) + }) +} + +func TestSuspendConnectRoutingApply(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTelemetry := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCtx := mock_telemetry.NewMockTelemetryContext(ctrl) + + ctxBefore := context.Background() + mockTelemetry.EXPECT(). + OpenTelemetryContext(gomock.Any(), telemetry.NESTED, ctxBefore). + Return(mockTelemetryCtx, ctxBefore).AnyTimes() + mockTelemetryCtx.EXPECT().CloseContext().AnyTimes() + + hostAndPort := "original-host:5432" + role := driver_infrastructure.SOURCE + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + bgId := "test-bg-deployment-123" + bgInProgressStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.IN_PROGRESS, nil, nil, nil, nil) + bgPostStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.POST, nil, nil, nil, nil) + routing := bg.NewSuspendConnectRouting(hostAndPort, role, bgId) + props := map[string]string{ + property_util.BG_CONNECT_TIMEOUT_MS.Name: "45", + } + assert.NotNil(t, routing, "Should create SuspendConnectRouting instance") + + t.Run("NilBgStatus", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(driver_infrastructure.BlueGreenStatus{}, false).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment switchover is completed. Continue with connect call.")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed <= 45*time.Millisecond, "Should not sleep for the requested duration") + }) + + t.Run("BgStatusStaysInProgress", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgInProgressStatus, true).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(gomock.Any()).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment switchover is still in progress")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed >= 45*time.Millisecond, "Should sleep for at least the requested duration") + assert.True(t, elapsed < 110*time.Millisecond, "Should not sleep much longer than requested. Slept for %d ms.", elapsed.Milliseconds()) + }) + + t.Run("BgStatusChanges", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgInProgressStatus, true).Times(2) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgPostStatus, true) + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment switchover is completed. Continue with connect call.")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed >= 45*time.Millisecond, "Should sleep for at least the requested duration") + assert.True(t, elapsed < 110*time.Millisecond, "Should not sleep much longer than requested. Slept for %d ms.", elapsed.Milliseconds()) + }) +} + +func TestSuspendUntilCorrespondingHostFoundConnectRoutingApply(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTelemetry := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCtx := mock_telemetry.NewMockTelemetryContext(ctrl) + + ctxBefore := context.Background() + mockTelemetry.EXPECT(). + OpenTelemetryContext(gomock.Any(), telemetry.NESTED, ctxBefore). + Return(mockTelemetryCtx, ctxBefore).AnyTimes() + mockTelemetryCtx.EXPECT().CloseContext().AnyTimes() + + hostAndPort := "test-host:5432" + role := driver_infrastructure.SOURCE + bgId := "test-bg-deployment-456" + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(5432).Build() + + routing := bg.NewSuspendUntilCorrespondingHostFoundConnectRouting(hostAndPort, role, bgId) + props := map[string]string{ + property_util.BG_CONNECT_TIMEOUT_MS.Name: "45", + } + assert.NotNil(t, routing, "Should create SuspendUntilCorrespondingNodeFoundConnectRouting instance") + + t.Run("NilBgStatus", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(driver_infrastructure.BlueGreenStatus{}, false).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment status is completed. Continue with 'connect' call. The call was held for")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed <= 45*time.Millisecond, "Should not sleep for the requested duration. Time elapsed: %d.", elapsed.Milliseconds()) + }) + + t.Run("BgStatusCompleted", func(t *testing.T) { + bgCompletedStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.COMPLETED, nil, nil, nil, nil) + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgCompletedStatus, true).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment status is completed. Continue with 'connect' call. The call was held for")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed <= 45*time.Millisecond, "Should not sleep for the requested duration. Time elapsed: %d.", elapsed.Milliseconds()) + }) + + t.Run("TimeoutWaitingForCorrespondingHost", func(t *testing.T) { + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + correspondingHosts.Put("test-host", utils.NewPair(hostInfo, &host_info_util.HostInfo{})) + bgCompletedStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.POST, nil, nil, nil, correspondingHosts) + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgCompletedStatus, true).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(gomock.Any()).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + require.NotNil(t, err, "Should return error") + assert.True(t, strings.Contains(err.Error(), "Blue/Green Deployment switchover is still in progress and a corresponding host for 'test-host' is not found")) + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed >= 45*time.Millisecond, "Should sleep for at least the requested duration") + }) + t.Run("FindCorrespondingHost", func(t *testing.T) { + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + correspondingHosts.Put("test-host", utils.NewPair(hostInfo, hostInfo)) + bgCompletedStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.POST, nil, nil, nil, correspondingHosts) + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgCompletedStatus, true).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(gomock.Any()).Times(2) + + start := time.Now() + conn, err := routing.Apply(nil, hostInfo, props, true, mockPluginService) + assert.Nil(t, err, "Should not return error") + assert.Nil(t, conn, "Should not create a connection") + + elapsed := time.Since(start) + assert.True(t, elapsed <= 45*time.Millisecond, "Should not sleep for the requested duration. Time elapsed: %d.", elapsed.Milliseconds()) + }) +} + +func TestSuspendExecuteRoutingApply(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockTelemetry := mock_telemetry.NewMockTelemetryFactory(ctrl) + mockTelemetryCtx := mock_telemetry.NewMockTelemetryContext(ctrl) + + ctxBefore := context.Background() + mockTelemetry.EXPECT(). + OpenTelemetryContext(gomock.Any(), telemetry.NESTED, ctxBefore). + Return(mockTelemetryCtx, ctxBefore).AnyTimes() + mockTelemetryCtx.EXPECT().CloseContext().AnyTimes() + + hostAndPort := "test-host:5432" + role := driver_infrastructure.SOURCE + bgId := "test-bg-id" + bgInProgressStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.IN_PROGRESS, nil, nil, nil, nil) + bgPostStatus := driver_infrastructure.NewBgStatus(bgId, driver_infrastructure.POST, nil, nil, nil, nil) + + routing := bg.NewSuspendExecuteRouting(hostAndPort, role, bgId) + props := map[string]string{ + property_util.BG_CONNECT_TIMEOUT_MS.Name: "45", + } + assert.NotNil(t, routing, "Should create SuspendExecuteRouting instance") + + methodName := "testMethod" + methodFunc := func() (any, any, bool, error) { + return nil, nil, false, nil + } + + t.Run("NilBgStatus", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(driver_infrastructure.BlueGreenStatus{}, false).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + result := routing.Apply(nil, props, mockPluginService, methodName, methodFunc) + assert.False(t, result.IsPresent(), "Should return empty result when no blue/green status") + + elapsed := time.Since(start) + assert.True(t, elapsed <= 45*time.Millisecond, "Should not sleep for the requested duration") + }) + + t.Run("BgStatusStaysInProgress", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgInProgressStatus, true).AnyTimes() + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(gomock.Any()).Times(2) + + start := time.Now() + result := routing.Apply(nil, props, mockPluginService, methodName, methodFunc) + assert.True(t, result.IsPresent(), "Should return result with error") + assert.NotNil(t, result.WrappedErr, "Should return error") + assert.True(t, strings.Contains(result.WrappedErr.Error(), "Blue/Green Deployment switchover is still in progress")) + + elapsed := time.Since(start) + assert.True(t, elapsed >= 45*time.Millisecond, "Should sleep for at least the requested duration") + assert.True(t, elapsed < 110*time.Millisecond, "Should not sleep much longer than requested. Slept for %d ms.", elapsed.Milliseconds()) + }) + + t.Run("BgStatusChanges", func(t *testing.T) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgInProgressStatus, true).Times(2) + mockPluginService.EXPECT().GetStatus(bgId).Return(bgPostStatus, true) + mockPluginService.EXPECT().GetTelemetryContext().Return(ctxBefore).Times(1) + mockPluginService.EXPECT().GetTelemetryFactory().Return(mockTelemetry).Times(1) + mockPluginService.EXPECT().SetTelemetryContext(ctxBefore).Times(2) + + start := time.Now() + result := routing.Apply(nil, props, mockPluginService, methodName, methodFunc) + assert.False(t, result.IsPresent(), "Should return empty result when switchover completes") + + elapsed := time.Since(start) + assert.True(t, elapsed >= 45*time.Millisecond, "Should sleep for at least the requested duration") + assert.True(t, elapsed < 110*time.Millisecond, "Should not sleep much longer than requested. Slept for %d ms.", elapsed.Milliseconds()) + }) +} + +func TestBaseRoutingDelay(t *testing.T) { + routing := bg.NewBaseRouting("test-host:5432", driver_infrastructure.SOURCE) + + zeroStatus := driver_infrastructure.BlueGreenStatus{} + + start := time.Now() + routing.Delay(50*time.Millisecond, zeroStatus, nil, "") + elapsed := time.Since(start) + + assert.True(t, elapsed >= 40*time.Millisecond, "Should sleep for at least the requested duration") + assert.True(t, elapsed < 100*time.Millisecond, "Should not sleep much longer than requested") +} diff --git a/.test/test/bg_status_monitor_test.go b/.test/test/bg_status_monitor_test.go new file mode 100644 index 00000000..29adec6c --- /dev/null +++ b/.test/test/bg_status_monitor_test.go @@ -0,0 +1,727 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "database/sql/driver" + "errors" + "testing" + "time" + + mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" + mock_database_sql_driver "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/database_sql_driver" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestNewBlueGreenStatusMonitor(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + driver_infrastructure.INCREASED: 60000, + driver_infrastructure.HIGH: 5000, + } + + callbackCalled := false + onStatusChangeFunc := func(role driver_infrastructure.BlueGreenRole, interimStatus bg.BlueGreenInterimStatus) { + callbackCalled = true + } + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + onStatusChangeFunc, + ) + + assert.NotNil(t, monitor) + assert.False(t, callbackCalled) + + // Test getting initial values + topology := monitor.GetCurrentTopology() + assert.NotNil(t, topology) + assert.Len(t, topology, 0) + + ip := monitor.GetConnectedIpAddress() + assert.Equal(t, "", ip) + monitor.SetConnectedIpAddress("localhost") + assert.Equal(t, "localhost", monitor.GetConnectedIpAddress()) + + // Test setting and getting interval rate. + assert.Equal(t, driver_infrastructure.BASELINE, monitor.GetIntervalRate()) + monitor.SetIntervalRate(driver_infrastructure.INCREASED) + assert.Equal(t, driver_infrastructure.INCREASED, monitor.GetIntervalRate()) + monitor.SetIntervalRate(driver_infrastructure.HIGH) + assert.Equal(t, driver_infrastructure.HIGH, monitor.GetIntervalRate()) + + // Test getting ip addresses from hosts + ip = monitor.GetIpAddress("localhost") + assert.NotEmpty(t, ip) + + ip = monitor.GetIpAddress("invalid-host-that-does-not-exist.example.com") + assert.Empty(t, ip) +} + +func TestBlueGreenStatusMonitorDelay(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + + mockPluginService.EXPECT().GetDialect().Return(&driver_infrastructure.MySQLDatabaseDialect{}).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + nil, + mockPluginService, + nil, + map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 120, + driver_infrastructure.INCREASED: 60, + driver_infrastructure.HIGH: 45, + }, nil) + + start := time.Now() + monitor.Delay() + elapsed := time.Since(start) + + assert.True(t, elapsed < 40*time.Millisecond, "Should not delay if monitor is set to stop.") + + monitor.SetStop(false) + monitor.SetPanicMode(false) + start = time.Now() + monitor.Delay() + elapsed = time.Since(start) + + assert.True(t, elapsed >= 120*time.Millisecond, "Should delay at least the request time.") + assert.True(t, elapsed <= 200*time.Millisecond, "Should not delay much longer than the requested time.") + + start = time.Now() + monitor.SetPanicMode(true) + monitor.Delay() + elapsed = time.Since(start) + + assert.True(t, elapsed >= 45*time.Millisecond, "Should delay at least the request time.") + assert.True(t, elapsed <= 100*time.Millisecond, "Should not delay much longer than the requested time.") +} + +func TestBlueGreenStatusMonitorUpdateIpAddressFlags(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + hostInfo1, _ := host_info_util.NewHostInfoBuilder().SetHost("host1.example.com").SetPort(3306).Build() + hostInfo2, _ := host_info_util.NewHostInfoBuilder().SetHost("host2.example.com").SetPort(3306).Build() + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + + getMonitor := func() *bg.TestBlueGreenStatusMonitor { + return bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo1, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + } + + t.Run("CollectedTopologyTrue", func(t *testing.T) { + monitor := getMonitor() + monitor.SetCollectedTopology(true) + monitor.SetAllStartTopologyIpChanged(true) + monitor.SetAllStartTopologyEndpointsRemoved(true) + monitor.SetAllTopologyChanged(true) + + monitor.UpdateIpAddressFlags() + + assert.False(t, monitor.GetAllStartTopologyIpChanged()) + assert.False(t, monitor.GetAllStartTopologyEndpointsRemoved()) + assert.False(t, monitor.GetAllTopologyChanged()) + }) + + t.Run("CollectedTopologyFalseEmptyStartTopology", func(t *testing.T) { + monitor := getMonitor() + monitor.SetStartTopology([]*host_info_util.HostInfo{}) + monitor.SetCollectedTopology(false) + monitor.UpdateIpAddressFlags() + + // Should be false when start topology is empty + assert.False(t, monitor.GetAllStartTopologyIpChanged()) + assert.False(t, monitor.GetAllStartTopologyEndpointsRemoved()) + assert.False(t, monitor.GetAllTopologyChanged()) + }) + + t.Run("CollectedTopologyFalseCollectedIpAddressesFalseNoIpChange", func(t *testing.T) { + monitor := getMonitor() + monitor.GetStartIpAddressesByHostMap().Put("host1.example.com", "192.168.1.1") + monitor.GetStartIpAddressesByHostMap().Put("host2.example.com", "192.168.1.2") + monitor.GetCurrentIpAddressesByHostMap().Put("host1.example.com", "192.168.1.1") + monitor.GetCurrentIpAddressesByHostMap().Put("host2.example.com", "192.168.1.2") + + currentTopology := []*host_info_util.HostInfo{hostInfo1} + monitor.SetStartTopology([]*host_info_util.HostInfo{hostInfo1, hostInfo2}) + monitor.SetCurrentTopology(¤tTopology) + monitor.SetCollectedTopology(false) + monitor.SetCollectedIpAddresses(false) + + monitor.UpdateIpAddressFlags() + + assert.False(t, monitor.GetAllStartTopologyIpChanged(), "IP addresses haven't changed, so AllStartTopologyIpChanged should be false") + assert.False(t, monitor.GetAllStartTopologyEndpointsRemoved(), "Endpoints still have IP addresses, so AllStartTopologyEndpointsRemoved should be false") + assert.False(t, monitor.GetAllTopologyChanged(), "Topology has changed but there is one host in common, so AllTopologyChanged should be false") + }) + + t.Run("CollectedTopologyFalseCollectedIpAddressesFalseAllIpChanged", func(t *testing.T) { + monitor := getMonitor() + currentTopology := []*host_info_util.HostInfo{hostInfo1, hostInfo2} + monitor.SetStartTopology([]*host_info_util.HostInfo{hostInfo1, hostInfo2}) + monitor.SetCurrentTopology(¤tTopology) + monitor.GetStartIpAddressesByHostMap().Put("host1.example.com", "192.168.1.1") + monitor.GetStartIpAddressesByHostMap().Put("host2.example.com", "192.168.1.2") + monitor.GetCurrentIpAddressesByHostMap().Put("host1.example.com", "192.168.2.1") + monitor.GetCurrentIpAddressesByHostMap().Put("host2.example.com", "192.168.2.2") + monitor.SetCollectedTopology(false) + monitor.SetCollectedIpAddresses(false) + + monitor.UpdateIpAddressFlags() + + assert.True(t, monitor.GetAllStartTopologyIpChanged(), "All IP addresses have changed, so AllStartTopologyIpChanged should be true") + assert.False(t, monitor.GetAllStartTopologyEndpointsRemoved(), "Endpoints still have IP addresses, so AllStartTopologyEndpointsRemoved should be false") + assert.False(t, monitor.GetAllTopologyChanged(), "Topology hasn't changed (same hosts), so AllTopologyChanged should be false") + }) + + t.Run("CollectedTopologyFalseAllEndpointsRemoved", func(t *testing.T) { + monitor := getMonitor() + monitor.SetStartTopology([]*host_info_util.HostInfo{hostInfo1, hostInfo2}) + monitor.SetCurrentTopology(nil) + monitor.GetStartIpAddressesByHostMap().Put("host1.example.com", "192.168.1.1") + monitor.GetStartIpAddressesByHostMap().Put("host2.example.com", "192.168.1.2") + monitor.GetCurrentIpAddressesByHostMap().Clear() + monitor.SetCollectedTopology(false) + monitor.SetCollectedIpAddresses(true) + + monitor.UpdateIpAddressFlags() + + assert.False(t, monitor.GetAllStartTopologyIpChanged(), "SetCollectedIpAddresses is true, should stay at initial value") + assert.True(t, monitor.GetAllStartTopologyEndpointsRemoved(), "No more IP addresses, should mark as changed") + assert.False(t, monitor.GetAllTopologyChanged(), "Current topology is empty, should mark as false") + }) + + t.Run("CollectedTopologyFalseAllTopologyChanged", func(t *testing.T) { + monitor := getMonitor() + currentTopology := []*host_info_util.HostInfo{hostInfo1} + monitor.SetStartTopology([]*host_info_util.HostInfo{hostInfo2}) + monitor.SetCurrentTopology(¤tTopology) + monitor.SetCollectedTopology(false) + monitor.SetCollectedIpAddresses(true) + + monitor.UpdateIpAddressFlags() + + assert.False(t, monitor.GetAllStartTopologyIpChanged(), "Empty ip list, should be false") + assert.False(t, monitor.GetAllStartTopologyEndpointsRemoved(), "Empty ip list, should be false") + assert.True(t, monitor.GetAllTopologyChanged(), "Current topology is empty, should mark as false") + }) +} + +func TestBlueGreenStatusMonitorCollectHostIpAddresses(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + + t.Run("EmptyHostNames", func(t *testing.T) { + monitor.GetCurrentIpAddressesByHostMap().Put("host1.example.com", "192.168.1.1") + assert.Equal(t, 1, monitor.GetCurrentIpAddressesByHostMap().Size()) + monitor.GetHostNames().Clear() + assert.Equal(t, 0, monitor.GetHostNames().Size()) + + monitor.CollectHostIpAddresses() + + assert.Equal(t, 0, monitor.GetCurrentIpAddressesByHostMap().Size()) + }) + + t.Run("WithHostNamesCollectedIpAddressesFalse", func(t *testing.T) { + monitor.GetHostNames().Put("localhost", true) + + monitor.GetCurrentIpAddressesByHostMap().Clear() + monitor.GetStartIpAddressesByHostMap().Clear() + monitor.SetCollectedIpAddresses(false) + + monitor.CollectHostIpAddresses() + + assert.Equal(t, 1, monitor.GetCurrentIpAddressesByHostMap().Size()) + + localhostIp, exists := monitor.GetCurrentIpAddressesByHostMap().Get("localhost") + assert.True(t, exists) + assert.NotEmpty(t, localhostIp) + + assert.Equal(t, 0, monitor.GetStartIpAddressesByHostMap().Size()) + }) + + t.Run("WithHostNamesCollectedIpAddressesTrue", func(t *testing.T) { + monitor.GetHostNames().Clear() + monitor.GetHostNames().Put("localhost", true) + monitor.GetCurrentIpAddressesByHostMap().Clear() + monitor.GetStartIpAddressesByHostMap().Clear() + monitor.SetCollectedIpAddresses(true) + + monitor.CollectHostIpAddresses() + + assert.Equal(t, 1, monitor.GetCurrentIpAddressesByHostMap().Size()) + + localhostIp, exists := monitor.GetCurrentIpAddressesByHostMap().Get("localhost") + assert.True(t, exists) + assert.NotEmpty(t, localhostIp) + + assert.Equal(t, 1, monitor.GetStartIpAddressesByHostMap().Size()) + localhostStartIp, exists := monitor.GetStartIpAddressesByHostMap().Get("localhost") + assert.True(t, exists) + assert.Equal(t, localhostIp, localhostStartIp) + }) +} + +func TestBlueGreenStatusMonitorCollectTopology(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockConn.EXPECT().Close().Return(nil).AnyTimes() + var mockDriverConn driver.Conn = mockConn + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + + t.Run("NoHostListProvider", func(t *testing.T) { + monitor, mockDriverDialect, _ := collectTopologySetUp(hostInfo, ctrl) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(true).AnyTimes() + err := monitor.CollectTopology() + assert.NoError(t, err, "Should not error when no host list provider is set.") + }) + + t.Run("NoConn", func(t *testing.T) { + monitor, mockDriverDialect, mockHostListProvider := collectTopologySetUp(hostInfo, ctrl) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(true).AnyTimes() + monitor.SetHostListProvider(mockHostListProvider) + err := monitor.CollectTopology() + assert.NoError(t, err, "Should not error when no connection is set.") + }) + + t.Run("ConnClosed", func(t *testing.T) { + monitor, mockDriverDialect, mockHostListProvider := collectTopologySetUp(hostInfo, ctrl) + monitor.SetHostListProvider(mockHostListProvider) + monitor.SetConnection(&mockDriverConn) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(true).AnyTimes() + err := monitor.CollectTopology() + assert.NoError(t, err, "Should not error when connection is closed.") + }) + + t.Run("ForceRefreshError", func(t *testing.T) { + monitor, mockDriverDialect, mockHostListProvider := collectTopologySetUp(hostInfo, ctrl) + monitor.SetHostListProvider(mockHostListProvider) + monitor.SetConnection(&mockDriverConn) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).AnyTimes() + mockHostListProvider.EXPECT().ForceRefresh(gomock.Any()).Return(nil, errors.New("test-error")) + err := monitor.CollectTopology() + assert.Error(t, err, "Should error when ForceRefresh fails.") + }) + + t.Run("ForceRefreshSuccessNoCollection", func(t *testing.T) { + monitor, mockDriverDialect, mockHostListProvider := collectTopologySetUp(hostInfo, ctrl) + monitor.SetHostListProvider(mockHostListProvider) + monitor.SetConnection(&mockDriverConn) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).AnyTimes() + mockHostListProvider.EXPECT().ForceRefresh(gomock.Any()).Return([]*host_info_util.HostInfo{hostInfo}, nil) + _, hostInfoInHostNames := monitor.GetHostNames().Get(hostInfo.GetHost()) + assert.False(t, hostInfoInHostNames) + err := monitor.CollectTopology() + assert.NoError(t, err, "Should not error when ForceRefresh returns hosts.") + _, hostInfoInHostNames = monitor.GetHostNames().Get(hostInfo.GetHost()) + assert.False(t, hostInfoInHostNames) + }) + + t.Run("ForceRefreshSuccessCollection", func(t *testing.T) { + monitor, mockDriverDialect, mockHostListProvider := collectTopologySetUp(hostInfo, ctrl) + monitor.SetHostListProvider(mockHostListProvider) + monitor.SetConnection(&mockDriverConn) + monitor.SetCollectedTopology(true) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).AnyTimes() + mockHostListProvider.EXPECT().ForceRefresh(gomock.Any()).Return([]*host_info_util.HostInfo{hostInfo}, nil) + _, hostInfoInHostNames := monitor.GetHostNames().Get(hostInfo.GetHost()) + assert.False(t, hostInfoInHostNames) + err := monitor.CollectTopology() + assert.NoError(t, err, "Should not error when ForceRefresh returns hosts.") + _, hostInfoInHostNames = monitor.GetHostNames().Get(hostInfo.GetHost()) + assert.True(t, hostInfoInHostNames) + }) +} + +func collectTopologySetUp(hostInfo *host_info_util.HostInfo, ctrl *gomock.Controller) (*bg.TestBlueGreenStatusMonitor, + *mock_driver_infrastructure.MockDriverDialect, *mock_driver_infrastructure.MockHostListProvider) { + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockHostListProvider := mock_driver_infrastructure.NewMockHostListProvider(ctrl) + + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).AnyTimes() + mockPluginService.EXPECT().CreateHostListProvider(gomock.Any()).Return(mockHostListProvider).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + return monitor, mockDriverDialect, mockHostListProvider +} + +func TestBlueGreenStatusMonitorInitHostListProvider(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().CreateHostListProvider(gomock.Any()).Return(&driver_infrastructure.RdsHostListProvider{}).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + monitor.SetConnectionHostInfoCorrect(false) + assert.Nil(t, monitor.GetHostListProvider(), "HostListProvider should be nil upon monitor creation.") + + monitor.InitHostListProvider() + assert.Nil(t, monitor.GetHostListProvider(), "If ConnectionHostInfo is incorrect, will not initialize") + + monitor.SetConnectionHostInfoCorrect(true) + monitor.InitHostListProvider() + hostListProvider := monitor.GetHostListProvider() + assert.NotNil(t, hostListProvider, "Should initialize when hostInfo is correct") + + monitor.InitHostListProvider() + assert.NotNil(t, monitor.GetHostListProvider(), "Should stay initialized as the same value") + assert.Equal(t, hostListProvider, monitor.GetHostListProvider()) + + monitor.SetConnectionHostInfo(hostInfo) + monitor.SetHostListProvider(nil) + monitor.InitHostListProvider() + assert.NotNil(t, monitor.GetHostListProvider(), "Should initialize with additional values from hostInfo") +} + +func TestBlueGreenStatusMonitorOpenConnection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).AnyTimes() + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + + t.Run("InitialHostInfoSuccess", func(t *testing.T) { + mockPluginService.EXPECT().ForceConnect(gomock.Any(), gomock.Any()).Return(mockConn, nil) + monitor.OpenConnection() + + conn := monitor.GetConnection() + assert.NotNil(t, conn) + assert.False(t, monitor.GetPanicMode()) + }) + + t.Run("InitialHostInfoFailure", func(t *testing.T) { + monitor.SetConnectedIpAddress("localhost") + monitor.SetConnection(nil) + mockPluginService.EXPECT().ForceConnect(gomock.Any(), gomock.Any()).Return(nil, errors.New("test-error")) + monitor.OpenConnection() + + conn := monitor.GetConnection() + assert.Nil(t, conn) + assert.True(t, monitor.GetPanicMode()) + }) + + t.Run("IpAddress", func(t *testing.T) { + monitor.SetConnectedIpAddress("localhost") + monitor.SetUseIpAddress(true) + mockPluginService.EXPECT().ForceConnect(gomock.Any(), gomock.Any()).Return(mockConn, nil) + monitor.OpenConnection() + + conn := monitor.GetConnection() + assert.NotNil(t, conn) + assert.False(t, monitor.GetPanicMode()) + }) +} + +func TestBlueGreenStatusMonitorCloseConnection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).AnyTimes() + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + + // Set up a mock connection first + var mockConn driver.Conn = &MockConn{} + monitor.SetConnection(&mockConn) + + conn := monitor.GetConnection() + assert.NotNil(t, conn, "connection should be set to mockConn..") + + monitor.CloseConnection() + + conn = monitor.GetConnection() + assert.Nil(t, conn, "connection should be nil after being closed.") +} + +func TestBlueGreenStatusMonitorResetCollectedData(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + + monitor.GetHostNames().Put("test-host", true) + monitor.GetHostNames().Put("test-host-2", true) + assert.Equal(t, 2, monitor.GetHostNames().Size(), "Data should be added to HostNames.") + + monitor.ResetCollectedData() + assert.Equal(t, 0, monitor.GetHostNames().Size(), "HostNames should be cleared after reset.") +} + +func TestBlueGreenStatusMonitorCollectStatus(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + hostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("test-host").SetPort(3306).Build() + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: 300000, + } + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + + monitor := bg.NewTestBlueGreenStatusMonitor( + driver_infrastructure.TARGET, + "test-bg-id", + hostInfo, + mockPluginService, + map[string]string{}, + statusCheckIntervalMap, + nil, + ) + + t.Run("NoConnection", func(t *testing.T) { + monitor.CollectStatus() + assert.True(t, monitor.GetCurrentPhase().Equals(driver_infrastructure.NOT_CREATED)) + assert.True(t, monitor.GetPanicMode()) + }) + + t.Run("WithOpenConnectionStatusUnavailable", func(t *testing.T) { + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).Times(2) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).Times(2) + mockDialect.EXPECT().IsBlueGreenStatusAvailable(gomock.Any()).Return(false) + + var mockConn driver.Conn = &MockConn{} + monitor.SetConnection(&mockConn) + + monitor.CollectStatus() + + assert.Equal(t, driver_infrastructure.NOT_CREATED, monitor.GetCurrentPhase(), + "Unavailable status should return NOT_CREATED phase") + assert.True(t, monitor.GetPanicMode()) + }) + + t.Run("WithOpenThenClosedConnectionStatusUnavailable", func(t *testing.T) { + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).Times(2) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).Times(1) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(true).Times(1) + mockDialect.EXPECT().IsBlueGreenStatusAvailable(gomock.Any()).Return(false) + + var mockConn driver.Conn = &MockConn{} + monitor.SetConnection(&mockConn) + + monitor.CollectStatus() + + assert.True(t, monitor.GetCurrentPhase().IsZero(), + "When connection closes unexpectedly, phase should be 0") + assert.True(t, monitor.GetPanicMode()) + }) + + t.Run("StatusAvailableNilStatusInfo", func(t *testing.T) { + monitor.SetCollectedTopology(true) + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).Times(1) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).Times(1) + mockDialect.EXPECT().IsBlueGreenStatusAvailable(gomock.Any()).Return(true) + mockDialect.EXPECT().GetBlueGreenStatus(gomock.Any()).Return([]driver_infrastructure.BlueGreenResult{ + { + Version: "1.0", + Endpoint: "prod-aurora-cluster.cluster-abc123def456.us-east-1.rds.amazonaws.com", + Port: 5432, + Role: "BLUE_GREEN_DEPLOYMENT_SOURCE", + Status: "AVAILABLE", + }, + }) + var mockConn driver.Conn = &MockConn{} + monitor.SetConnection(&mockConn) + + monitor.CollectStatus() + + assert.True(t, monitor.GetCurrentPhase().IsZero(), + "Phase should be 0 when there is no matching status") + assert.True(t, monitor.GetPanicMode()) + }) + + t.Run("StatusAvailable", func(t *testing.T) { + monitor.SetCollectedTopology(true) + mockDriverDialect := mock_driver_infrastructure.NewMockDriverDialect(ctrl) + mockPluginService.EXPECT().GetTargetDriverDialect().Return(mockDriverDialect).Times(1) + mockDriverDialect.EXPECT().IsClosed(gomock.Any()).Return(false).Times(1) + mockDialect.EXPECT().IsBlueGreenStatusAvailable(gomock.Any()).Return(true) + mockPluginService.EXPECT().CreateHostListProvider(gomock.Any()).Return(&driver_infrastructure.RdsHostListProvider{}) + mockDialect.EXPECT().GetBlueGreenStatus(gomock.Any()).Return([]driver_infrastructure.BlueGreenResult{ + { + Version: "1.0", + Endpoint: "prod-aurora-cluster.cluster-abc123def456.us-east-1.rds.amazonaws.com", + Port: 5432, + Role: "BLUE_GREEN_DEPLOYMENT_SOURCE", + Status: "AVAILABLE", + }, + { + Version: "1.1", + Endpoint: "prod-aurora-cluster-target.cluster-xyz789def456.us-east-1.rds.amazonaws.com", + Port: 5432, + Role: "BLUE_GREEN_DEPLOYMENT_TARGET", + Status: "SWITCHOVER_IN_PROGRESS", + }, + }) + var mockConn driver.Conn = &MockConn{} + monitor.SetConnection(&mockConn) + + monitor.CollectStatus() + + assert.Equal(t, driver_infrastructure.IN_PROGRESS, monitor.GetCurrentPhase(), + "Status should match the BlueGreenResult for TARGET") + assert.False(t, monitor.GetPanicMode()) + }) +} diff --git a/.test/test/bg_status_provider_test.go b/.test/test/bg_status_provider_test.go new file mode 100644 index 00000000..77ee9c20 --- /dev/null +++ b/.test/test/bg_status_provider_test.go @@ -0,0 +1,678 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "testing" + "time" + + mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestBlueGreenStatusProviderGetMonitoringProperties(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := &driver_infrastructure.MySQLDatabaseDialect{} + + props := map[string]string{ + property_util.BG_INTERVAL_BASELINE_MS.Name: "300000", + property_util.BG_INTERVAL_INCREASED_MS.Name: "60000", + property_util.BG_INTERVAL_HIGH_MS.Name: "5000", + property_util.BG_SWITCHOVER_TIMEOUT_MS.Name: "600000", + property_util.BG_SUSPEND_NEW_BLUE_CONNECTIONS.Name: "false", + property_util.BG_PROPERTY_PREFIX + "user": "testuser", + property_util.BG_PROPERTY_PREFIX + "password": "testpass", + "normalProp": "normalValue", + } + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewBlueGreenStatusProvider(mockPluginService, props, "test-bg-id") + assert.NotNil(t, provider) + provider.ClearMonitors() + + monitoringProps := provider.GetMonitoringProperties() + + // BG prefixed properties should be stripped of prefix + assert.Equal(t, "testuser", monitoringProps["user"]) + assert.Equal(t, "testpass", monitoringProps["password"]) + + // Normal properties should remain + assert.Equal(t, "normalValue", monitoringProps["normalProp"]) + + // BG prefixed properties should be removed from monitoring props + assert.NotContains(t, monitoringProps, property_util.BG_PROPERTY_PREFIX+"user") + assert.NotContains(t, monitoringProps, property_util.BG_PROPERTY_PREFIX+"password") +} + +func TestBlueGreenStatusProviderUpdatePhaseForward(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + interimStatus1 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.CREATED, + nil, nil, false, false, false) + interimStatus2 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.PREPARATION, + nil, nil, false, false, false) + + provider.UpdatePhase(driver_infrastructure.SOURCE, interimStatus1) + assert.Equal(t, driver_infrastructure.CREATED, provider.GetLatestStatusPhase()) + + provider.UpdatePhase(driver_infrastructure.SOURCE, interimStatus2) + assert.Equal(t, driver_infrastructure.PREPARATION, provider.GetLatestStatusPhase()) +} + +func TestBlueGreenStatusProviderUpdatePhaseRollback(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + interimStatus1 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.PREPARATION, + nil, nil, false, false, false) + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = interimStatus1 + provider.UpdatePhase(driver_infrastructure.SOURCE, interimStatus1) + assert.Equal(t, driver_infrastructure.PREPARATION, provider.GetLatestStatusPhase()) + + interimStatus2 := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.CREATED, + nil, nil, false, false, false) + provider.UpdatePhase(driver_infrastructure.SOURCE, interimStatus2) + assert.True(t, provider.GetRollback()) + assert.Equal(t, driver_infrastructure.CREATED, provider.GetLatestStatusPhase()) +} + +func TestBlueGreenStatusProviderGetStatusOfCreated(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + status := provider.GetStatusOfCreated() + + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.CREATED, status.GetCurrentPhase()) + assert.Empty(t, status.GetConnectRoutings()) + assert.Empty(t, status.GetExecuteRoutings()) +} + +func TestBlueGreenStatusProviderGetStatusOfPreparation(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + status := provider.GetStatusOfPreparation() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.PREPARATION, status.GetCurrentPhase()) + assert.NotNil(t, status.GetConnectRoutings()) + + provider.SetPostStatusEndTime(time.Now()) + status = provider.GetStatusOfPreparation() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.COMPLETED, status.GetCurrentPhase(), + "If expired and not rolling back, assign status of COMPLETED") + + provider.SetRollback(true) + status = provider.GetStatusOfPreparation() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.CREATED, status.GetCurrentPhase(), + "If expired and rolling back, assign status of CREATED") +} + +func TestBlueGreenStatusProviderGetStatusOfInProgress(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + t.Run("NilHostIpAddresses", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.IN_PROGRESS, status.GetCurrentPhase()) + assert.Equal(t, 1, len(status.GetConnectRoutings())) + assert.Equal(t, 2, len(status.GetExecuteRoutings())) + }) + + t.Run("NilGetInterimStatuses()", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + provider.GetHostIpAddresses().Put("blue-host", "192.168.1.1") + provider.GetHostIpAddresses().Put("green-host", "192.168.1.2") + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.IN_PROGRESS, status.GetCurrentPhase()) + assert.Equal(t, 1, len(status.GetConnectRoutings())) + assert.Equal(t, 4, len(status.GetExecuteRoutings())) + }) + + t.Run("SuspendNewBlueConnectionsWhenInProgressFalse", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + provider.GetHostIpAddresses().Put("blue-host", "192.168.1.1") + provider.GetHostIpAddresses().Put("green-host", "192.168.1.2") + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = + bg.NewTestBlueGreenInterimStatus(driver_infrastructure.IN_PROGRESS, + nil, map[string]string{"blue-host": "192.168.1.1"}, false, false, false) + provider.GetInterimStatuses()[driver_infrastructure.TARGET.GetValue()] = + bg.NewTestBlueGreenInterimStatus(driver_infrastructure.IN_PROGRESS, + nil, map[string]string{"green-host": "192.168.1.2"}, false, false, false) + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.IN_PROGRESS, status.GetCurrentPhase()) + assert.Equal(t, 3, len(status.GetConnectRoutings())) + assert.Equal(t, 5, len(status.GetExecuteRoutings())) + }) + + t.Run("SuspendNewBlueConnectionsWhenInProgressTrue", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, map[string]string{ + property_util.BG_SUSPEND_NEW_BLUE_CONNECTIONS.Name: "true", + }, "test-bg-id") + provider.ClearMonitors() + provider.GetHostIpAddresses().Put("blue-host", "192.168.1.1") + provider.GetHostIpAddresses().Put("green-host", "192.168.1.2") + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = + bg.NewTestBlueGreenInterimStatus(driver_infrastructure.IN_PROGRESS, + nil, map[string]string{"blue-host": "192.168.1.1"}, false, false, false) + provider.GetInterimStatuses()[driver_infrastructure.TARGET.GetValue()] = + bg.NewTestBlueGreenInterimStatus(driver_infrastructure.IN_PROGRESS, + nil, map[string]string{"green-host": "192.168.1.2"}, false, false, false) + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.IN_PROGRESS, status.GetCurrentPhase()) + assert.Equal(t, 6, len(status.GetConnectRoutings())) + assert.Equal(t, 6, len(status.GetExecuteRoutings())) + }) + + t.Run("PastEndTime", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + provider.SetPostStatusEndTime(time.Now()) + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.COMPLETED, status.GetCurrentPhase(), + "If expired and not rolling back, assign status of COMPLETED") + }) + + t.Run("PastEndTimeRollback", func(t *testing.T) { + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + provider.SetPostStatusEndTime(time.Now()) + provider.SetRollback(true) + status := provider.GetStatusOfInProgress() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.CREATED, status.GetCurrentPhase(), + "If expired and rolling back, assign status of CREATED") + }) +} + +func TestBlueGreenStatusProviderGetStatusOfPost(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + status := provider.GetStatusOfPost() + + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.POST, status.GetCurrentPhase()) + assert.NotNil(t, status.GetConnectRoutings()) + assert.Empty(t, status.GetExecuteRoutings()) + + provider.SetPostStatusEndTime(time.Now()) + status = provider.GetStatusOfPost() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.COMPLETED, status.GetCurrentPhase(), + "If expired and not rolling back, assign status of COMPLETED") + + provider.SetRollback(true) + status = provider.GetStatusOfPost() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.CREATED, status.GetCurrentPhase(), + "If expired and rolling back, assign status of CREATED") +} + +func TestBlueGreenStatusProviderGetStatusOfCompleted(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + provider.SetBlueDnsUpdateCompleted(true) + provider.SetGreenDnsRemoved(true) + + status := provider.GetStatusOfCompleted() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.COMPLETED, status.GetCurrentPhase()) + assert.Empty(t, status.GetConnectRoutings()) + assert.Empty(t, status.GetExecuteRoutings()) + + provider.SetBlueDnsUpdateCompleted(false) + status = provider.GetStatusOfCompleted() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.POST, status.GetCurrentPhase(), + "DNS not updated to reflect completion yet, mark as POST") + + provider.SetBlueDnsUpdateCompleted(true) + provider.SetGreenDnsRemoved(false) + status = provider.GetStatusOfCompleted() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.POST, status.GetCurrentPhase(), + "DNS not updated to reflect completion yet, mark as POST") + + provider.SetPostStatusEndTime(time.Now()) + status = provider.GetStatusOfPost() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.COMPLETED, status.GetCurrentPhase(), + "If expired and not rolling back, assign status of COMPLETED") + + provider.SetRollback(true) + status = provider.GetStatusOfPost() + assert.Equal(t, "test-bg-id", status.GetBgId()) + assert.Equal(t, driver_infrastructure.CREATED, status.GetCurrentPhase(), + "If expired and rolling back, assign status of CREATED") +} + +func TestBlueGreenStatusProviderRegisterIamHost(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + provider.RegisterIamHost("green-host", "blue-host") + assert.True(t, provider.IsAlreadySuccessfullyConnected("green-host", "blue-host")) + assert.False(t, provider.IsAlreadySuccessfullyConnected("green-host", "other-host")) +} + +func TestBlueGreenStatusProviderGetWriterHost(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test with empty interim status + writerHost := provider.GetWriterHost(driver_infrastructure.SOURCE) + assert.Nil(t, writerHost) + + // Test with topology containing writer + writerHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("writer-host").SetRole(host_info_util.WRITER).Build() + readerHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("reader-host").SetRole(host_info_util.READER).Build() + + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + []*host_info_util.HostInfo{writerHostInfo, readerHostInfo}, nil, false, false, false) + + writerHost = provider.GetWriterHost(driver_infrastructure.SOURCE) + assert.NotNil(t, writerHost) + assert.Equal(t, "writer-host", writerHost.GetHost()) +} + +func TestBlueGreenStatusProviderGetReaderHosts(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test with empty interim status + readerHosts := provider.GetReaderHosts(driver_infrastructure.SOURCE) + assert.Nil(t, readerHosts) + + // Test with topology containing readers + writerHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost("writer-host").SetRole(host_info_util.WRITER).Build() + readerHostInfo1, _ := host_info_util.NewHostInfoBuilder().SetHost("reader-host-1").SetRole(host_info_util.READER).Build() + readerHostInfo2, _ := host_info_util.NewHostInfoBuilder().SetHost("reader-host-2").SetRole(host_info_util.READER).Build() + + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + []*host_info_util.HostInfo{writerHostInfo, readerHostInfo1, readerHostInfo2}, nil, false, false, false) + + readerHosts = provider.GetReaderHosts(driver_infrastructure.SOURCE) + assert.NotNil(t, readerHosts) + assert.Len(t, readerHosts, 2) + assert.Equal(t, "reader-host-1", readerHosts[0].Host) + assert.Equal(t, "reader-host-2", readerHosts[1].Host) +} + +func TestBlueGreenStatusProviderStoreBlueDnsUpdateTime(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test storing blue DNS update time + provider.StoreBlueDnsUpdateTime() + + // Verify the time was stored + phaseTime, exists := provider.GetPhaseTimeNano().Get("Blue DNS updated") + assert.True(t, exists) + assert.True(t, time.Since(phaseTime.Timestamp) < time.Second) +} + +func TestBlueGreenStatusProviderStoreGreenDnsRemoveTime(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test storing green DNS remove time + provider.StoreGreenDnsRemoveTime() + + // Verify the time was stored + phaseTime, exists := provider.GetPhaseTimeNano().Get("Green DNS removed") + assert.True(t, exists) + assert.True(t, time.Since(phaseTime.Timestamp) < time.Second) +} + +func TestBlueGreenStatusProviderStoreGreenTopologyChangeTime(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test storing green topology change time + provider.StoreGreenTopologyChangeTime() + + // Verify the time was stored + phaseTime, exists := provider.GetPhaseTimeNano().Get("Green topology changed") + assert.True(t, exists) + assert.True(t, time.Since(phaseTime.Timestamp) < time.Second) +} + +func TestBlueGreenStatusProviderStartSwitchoverTimer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + props := map[string]string{ + property_util.BG_SWITCHOVER_TIMEOUT_MS.Name: "5000", // 5 seconds + } + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, props, "test-bg-id") + provider.ClearMonitors() + + // Initially, no timer should be set + assert.True(t, provider.GetPostStatusEndTime().IsZero()) + + // Start the switchover timer + provider.StartSwitchoverTimer() + + // Verify timer was set + assert.False(t, provider.GetPostStatusEndTime().IsZero()) + assert.True(t, provider.GetPostStatusEndTime().After(time.Now())) + + // Calling again should not change the timer + originalTime := provider.GetPostStatusEndTime() + provider.StartSwitchoverTimer() + assert.Equal(t, originalTime, provider.GetPostStatusEndTime()) +} + +func TestBlueGreenStatusProviderUpdateDnsFlags(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test blue DNS update completion + interimStatus := bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + nil, nil, true, false, false) + + assert.False(t, provider.GetBlueDnsUpdateCompleted()) + provider.UpdateDnsFlags(driver_infrastructure.SOURCE, interimStatus) + assert.True(t, provider.GetBlueDnsUpdateCompleted()) + + // Test green DNS removal + interimStatus = bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + nil, nil, false, true, false) + + assert.False(t, provider.GetGreenDnsRemoved()) + provider.UpdateDnsFlags(driver_infrastructure.TARGET, interimStatus) + assert.True(t, provider.GetGreenDnsRemoved()) + + // Test green topology change + interimStatus = bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + nil, nil, false, false, true) + + assert.False(t, provider.GetGreenTopologyChanged()) + provider.UpdateDnsFlags(driver_infrastructure.TARGET, interimStatus) + assert.True(t, provider.GetGreenTopologyChanged()) +} + +func TestBlueGreenStatusProviderAddSubstituteBlueWithIpAddressConnectRouting(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Set up test data + blueHost, _ := host_info_util.NewHostInfoBuilder().SetHost("blue-host").SetPort(3306).Build() + provider.GetRoleByHost().Put("blue-host", driver_infrastructure.SOURCE) + provider.GetCorrespondingHosts().Put("blue-host", utils.NewPair(blueHost, blueHost)) + provider.GetHostIpAddresses().Put("blue-host", "192.168.1.1") + provider.GetInterimStatuses()[driver_infrastructure.SOURCE.GetValue()] = bg.NewTestBlueGreenInterimStatus(driver_infrastructure.BlueGreenPhase{}, + nil, nil, false, false, false) + + // Test the method + routing := provider.AddSubstituteBlueWithIpAddressConnectRouting() + + // Verify routing was created + assert.NotEmpty(t, routing) + // Should have routing for both host and host:port + assert.Len(t, routing, 2) +} + +func TestBlueGreenStatusProviderCreatePostRouting(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Test with no DNS updates completed + provider.SetBlueDnsUpdateCompleted(false) + provider.SetAllGreenHostsChangedName(false) + + // Set up test data + blueHost, _ := host_info_util.NewHostInfoBuilder().SetHost("blue-host").SetPort(3306).Build() + greenHost, _ := host_info_util.NewHostInfoBuilder().SetHost("green-host").SetPort(3306).Build() + + provider.GetRoleByHost().Put("blue-host", driver_infrastructure.SOURCE) + provider.GetCorrespondingHosts().Put("blue-host", utils.NewPair(blueHost, greenHost)) + provider.GetHostIpAddresses().Put("green-host", "192.168.1.2") + + routing := provider.CreatePostRouting() + assert.NotEmpty(t, routing) + + // Test with DNS updates completed + provider.SetBlueDnsUpdateCompleted(true) + provider.SetAllGreenHostsChangedName(true) + + routing = provider.CreatePostRouting() + assert.Empty(t, routing) // Should be empty when DNS updates are completed +} + +func TestBlueGreenStatusProviderResetContextWhenCompleted(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + // Set up completed state + provider.SetSummaryStatus(driver_infrastructure.NewBgStatus("test-bg-id", driver_infrastructure.COMPLETED, nil, nil, nil, nil)) + provider.SetRollback(false) + provider.GetPhaseTimeNano().Put("test-phase", bg.PhaseTimeInfo{ + Timestamp: time.Now(), + Phase: driver_infrastructure.COMPLETED, + }) + + // Set some state that should be reset + provider.SetGreenDnsRemoved(true) + provider.SetGreenTopologyChanged(true) + provider.SetAllGreenHostsChangedName(true) + + // Test reset + provider.ResetContextWhenCompleted() + + // Verify state was reset + assert.False(t, provider.GetRollback()) + assert.False(t, provider.GetGreenDnsRemoved()) + assert.False(t, provider.GetGreenTopologyChanged()) + assert.False(t, provider.GetAllGreenHostsChangedName()) + assert.Equal(t, 0, provider.GetPhaseTimeNano().Size()) +} + +func TestBlueGreenStatusProviderPutIfAbsentPhaseTime(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockPluginService := mock_driver_infrastructure.NewMockPluginService(ctrl) + mockDialect := mock_driver_infrastructure.NewMockBlueGreenDialect(ctrl) + mockPluginService.EXPECT().GetDialect().Return(mockDialect).AnyTimes() + mockPluginService.EXPECT().GetCurrentHostInfo().Return(&host_info_util.HostInfo{Host: "test-host"}, nil).AnyTimes() + + provider := bg.NewTestBlueGreenStatusProvider(mockPluginService, nil, "test-bg-id") + provider.ClearMonitors() + + phase := driver_infrastructure.CREATED + + // Test normal case + provider.PutIfAbsentPhaseTime("test-phase", phase) + phaseTime, exists := provider.GetPhaseTimeNano().Get("test-phase") + assert.True(t, exists) + assert.Equal(t, phase, phaseTime.Phase) + + // Test rollback case + provider.SetRollback(true) + provider.PutIfAbsentPhaseTime("rollback-phase", phase) + phaseTime, exists = provider.GetPhaseTimeNano().Get("rollback-phase (rollback)") + assert.True(t, exists) + assert.Equal(t, phase, phaseTime.Phase) + + // Test that existing entries are not overwritten + originalTime := phaseTime.Timestamp + time.Sleep(1 * time.Millisecond) // Ensure different timestamp + provider.PutIfAbsentPhaseTime("rollback-phase", phase) + phaseTime, _ = provider.GetPhaseTimeNano().Get("rollback-phase (rollback)") + assert.Equal(t, originalTime, phaseTime.Timestamp) // Should not change +} diff --git a/.test/test/default_plugin_test.go b/.test/test/default_plugin_test.go index 5d4a84c4..1283479e 100644 --- a/.test/test/default_plugin_test.go +++ b/.test/test/default_plugin_test.go @@ -36,7 +36,7 @@ import ( func TestDefaultPlugin_InitHostProvider(t *testing.T) { dp := &plugins.DefaultPlugin{} - err := dp.InitHostProvider("someUrl", map[string]string{}, nil, nil) + err := dp.InitHostProvider(map[string]string{}, nil, nil) assert.NoError(t, err) } diff --git a/.test/test/dsn_host_list_provider_test.go b/.test/test/dsn_host_list_provider_test.go index 0c739905..f5fb415d 100644 --- a/.test/test/dsn_host_list_provider_test.go +++ b/.test/test/dsn_host_list_provider_test.go @@ -23,7 +23,7 @@ import ( mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" - "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -35,12 +35,9 @@ func TestDsnHostListProvider_Refresh_Success(t *testing.T) { mockHostListService := mock_driver_infrastructure.NewMockHostListProviderService(ctrl) // Construct DSN string with host dsn := "postgresql://127.0.0.1:5432/db" + props, _ := utils.ParseDsn(dsn) - props := map[string]string{ - property_util.HOST.Name: "127.0.0.1", - } - - provider := driver_infrastructure.NewDsnHostListProvider(props, dsn, mockHostListService) + provider := driver_infrastructure.NewDsnHostListProvider(props, mockHostListService) // `init()` should call SetInitialConnectionHostInfo with parsed host mockHostListService.EXPECT().SetInitialConnectionHostInfo(gomock.Any()).Times(1) @@ -59,11 +56,9 @@ func TestDsnHostListProvider_ForceRefresh_UsesInit(t *testing.T) { mockHostListService := mock_driver_infrastructure.NewMockHostListProviderService(ctrl) dsn := "postgresql://127.0.0.1:5432/db" - props := map[string]string{ - property_util.HOST.Name: "127.0.0.1", - } + props, _ := utils.ParseDsn(dsn) - provider := driver_infrastructure.NewDsnHostListProvider(props, dsn, mockHostListService) + provider := driver_infrastructure.NewDsnHostListProvider(props, mockHostListService) mockHostListService.EXPECT().SetInitialConnectionHostInfo(gomock.Any()).Times(1) @@ -78,49 +73,53 @@ func TestDsnHostListProvider_CreateHost_BuildsCorrectly(t *testing.T) { defer ctrl.Finish() mockHostListService := mock_driver_infrastructure.NewMockHostListProviderService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockDatabaseDialect(ctrl) - mockHostListService.EXPECT().GetDialect().Return(mockDialect).Times(1) - mockDialect.EXPECT().GetDefaultPort().Return(3306).Times(1) + mockHostListService.EXPECT().GetDialect().Return(driver_infrastructure.DatabaseDialect(&driver_infrastructure.PgDatabaseDialect{})).Times(2) - props := map[string]string{ - property_util.HOST.Name: "127.0.0.1", - } + props, _ := utils.ParseDsn("postgresql://127.0.0.1:5432/db") - provider := driver_infrastructure.NewDsnHostListProvider(props, "postgresql://127.0.0.1:5432/db", mockHostListService) + provider := driver_infrastructure.NewDsnHostListProvider(props, mockHostListService) now := time.Now() - result := provider.CreateHost("some-host", host_info_util.READER, 1.0, 0.5, now) + result := provider.CreateHost("", host_info_util.READER, 1.0, 0.5, now) assert.Equal(t, "127.0.0.1", result.Host) - assert.Equal(t, 3306, result.Port) + assert.Equal(t, 5432, result.Port) assert.Equal(t, host_info_util.READER, result.Role) assert.Equal(t, host_info_util.AVAILABLE, result.Availability) - assert.Equal(t, 101, result.Weight) // 1.0 lag * 100 + 0.5 = 100.5 → rounds to 101 + assert.Equal(t, 101, result.Weight) // 1 lag * 100 + 0.5 = 100.5 → rounds to 101 + assert.Equal(t, now, result.LastUpdateTime) + + result = provider.CreateHost("some-host", host_info_util.WRITER, 2.1, 0.3, now) + assert.Equal(t, "some-host", result.Host) + assert.Equal(t, 5432, result.Port) + assert.Equal(t, host_info_util.WRITER, result.Role) + assert.Equal(t, host_info_util.AVAILABLE, result.Availability) + assert.Equal(t, 200, result.Weight) // 2 lag * 100 + 0.3 = 200.3 → rounds to 210 assert.Equal(t, now, result.LastUpdateTime) } func TestDsnHostListProvider_GetHostRole_ReturnsUnknown(t *testing.T) { - provider := driver_infrastructure.NewDsnHostListProvider(nil, "dsn", nil) + provider := driver_infrastructure.NewDsnHostListProvider(nil, nil) role := provider.GetHostRole(nil) assert.Equal(t, host_info_util.UNKNOWN, role) } func TestDsnHostListProvider_IdentifyConnection_Unsupported(t *testing.T) { - provider := driver_infrastructure.NewDsnHostListProvider(nil, "dsn", nil) + provider := driver_infrastructure.NewDsnHostListProvider(nil, nil) host, err := provider.IdentifyConnection(nil) assert.Nil(t, host) assert.Error(t, err) } func TestDsnHostListProvider_GetClusterId_Unsupported(t *testing.T) { - provider := driver_infrastructure.NewDsnHostListProvider(nil, "dsn", nil) + provider := driver_infrastructure.NewDsnHostListProvider(nil, nil) id, err := provider.GetClusterId() assert.Empty(t, id) assert.Error(t, err) } func TestDsnHostListProvider_IsStaticHostListProvider(t *testing.T) { - provider := driver_infrastructure.NewDsnHostListProvider(nil, "dsn", nil) + provider := driver_infrastructure.NewDsnHostListProvider(nil, nil) assert.True(t, provider.IsStaticHostListProvider()) } diff --git a/.test/test/failover_plugin_test.go b/.test/test/failover_plugin_test.go index 2673b4ce..439d2dbb 100644 --- a/.test/test/failover_plugin_test.go +++ b/.test/test/failover_plugin_test.go @@ -19,6 +19,9 @@ package test import ( "database/sql/driver" "errors" + "slices" + "testing" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" @@ -27,9 +30,7 @@ import ( "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" "github.com/aws/aws-advanced-go-wrapper/awssql/utils" "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" - "github.com/aws/aws-advanced-go-wrapper/mysql-driver" - "slices" - "testing" + mysql_driver "github.com/aws/aws-advanced-go-wrapper/mysql-driver" "github.com/stretchr/testify/assert" ) @@ -80,14 +81,12 @@ func newTestMockMonitoringRdsHostListProvider( hostListProviderService driver_infrastructure.HostListProviderService, databaseDialect driver_infrastructure.TopologyAwareDialect, properties map[string]string, - originalDsn string, pluginService driver_infrastructure.PluginService) *MockMonitoringRdsHostListProvider { provider := &MockMonitoringRdsHostListProvider{ MonitoringRdsHostListProvider: driver_infrastructure.NewMonitoringRdsHostListProvider( hostListProviderService, databaseDialect, properties, - originalDsn, pluginService, ), } @@ -163,7 +162,6 @@ type mockAuroraMysqlDialect struct { func (t *mockAuroraMysqlDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { return mockMonitoringRdsHostListProvider @@ -250,13 +248,13 @@ func initializeTest( pluginServiceImpl.SetDialect(&mockAuroraMysqlDialect{isRoleWriter: isRoleWriter}) mockPluginService := driver_infrastructure.PluginService(pluginServiceImpl) + mySqlTestDsnProps, _ := utils.ParseDsn(mysqlTestDsn) hostListProviderService := driver_infrastructure.HostListProviderService(pluginServiceImpl) mockMonitoringRdsHostListProvider = newTestMockMonitoringRdsHostListProvider( hostListProviderService, &mockAuroraMysqlDialect{isRoleWriter: isRoleWriter}, - props, - mysqlTestDsn, + utils.CombineMaps(props, mySqlTestDsnProps), pluginServiceImpl) hostListProviderService.SetHostListProvider(mockMonitoringRdsHostListProvider) @@ -272,7 +270,7 @@ func initializeTest( failoverPlugin, _ := plugins.NewFailoverPlugin(pluginServiceImpl, props) mockFailoverPlugin := &MockFailoverPlugin{FailoverPlugin: failoverPlugin} _ = mockPluginManager.Init(mockPluginService, []driver_infrastructure.ConnectionPlugin{mockFailoverPlugin, &defaultPlugin}) - _ = mockPluginManager.InitHostProvider(mysqlTestDsn, props, hostListProviderService) + _ = mockPluginManager.InitHostProvider(props, hostListProviderService) return mockFailoverPlugin, pluginServiceImpl } diff --git a/.test/test/implementations_test.go b/.test/test/implementations_test.go index 250a5635..e982bd2f 100644 --- a/.test/test/implementations_test.go +++ b/.test/test/implementations_test.go @@ -35,6 +35,7 @@ import ( "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/efm" federated_auth "github.com/aws/aws-advanced-go-wrapper/federated-auth" "github.com/aws/aws-advanced-go-wrapper/iam" @@ -51,6 +52,10 @@ func TestImplementations(t *testing.T) { var _ driver_infrastructure.DatabaseDialect = (*driver_infrastructure.PgDatabaseDialect)(nil) var _ driver_infrastructure.DatabaseDialect = (*driver_infrastructure.RdsPgDatabaseDialect)(nil) var _ driver_infrastructure.DatabaseDialect = (*driver_infrastructure.AuroraPgDatabaseDialect)(nil) + var _ driver_infrastructure.BlueGreenDialect = (*driver_infrastructure.RdsMySQLDatabaseDialect)(nil) + var _ driver_infrastructure.BlueGreenDialect = (*driver_infrastructure.AuroraMySQLDatabaseDialect)(nil) + var _ driver_infrastructure.BlueGreenDialect = (*driver_infrastructure.RdsPgDatabaseDialect)(nil) + var _ driver_infrastructure.BlueGreenDialect = (*driver_infrastructure.AuroraPgDatabaseDialect)(nil) var _ driver_infrastructure.TopologyAwareDialect = (*driver_infrastructure.MySQLTopologyAwareDatabaseDialect)(nil) var _ driver_infrastructure.TopologyAwareDialect = (*driver_infrastructure.PgTopologyAwareDatabaseDialect)(nil) var _ driver_infrastructure.TopologyAwareDialect = (*driver_infrastructure.AuroraMySQLDatabaseDialect)(nil) @@ -80,8 +85,10 @@ func TestImplementations(t *testing.T) { var _ driver_infrastructure.ConnectionPlugin = (*aws_secrets_manager.AwsSecretsManagerPlugin)(nil) var _ driver_infrastructure.ConnectionPlugin = (*okta.OktaAuthPlugin)(nil) var _ driver_infrastructure.ConnectionPlugin = (*federated_auth.FederatedAuthPlugin)(nil) + var _ driver_infrastructure.ConnectionPlugin = (*bg.BlueGreenPlugin)(nil) var _ driver_infrastructure.ConnectionPluginFactory = (*efm.HostMonitoringPluginFactory)(nil) var _ driver_infrastructure.ConnectionPluginFactory = (*plugins.FailoverPluginFactory)(nil) + var _ driver_infrastructure.ConnectionPluginFactory = (*bg.BlueGreenPluginFactory)(nil) var _ driver_infrastructure.ConnectionPluginFactory = (*iam.IamAuthPluginFactory)(nil) var _ driver_infrastructure.ConnectionPluginFactory = (*aws_secrets_manager.AwsSecretsManagerPluginFactory)(nil) var _ driver_infrastructure.ConnectionPluginFactory = (*okta.OktaAuthPluginFactory)(nil) @@ -129,4 +136,9 @@ func TestImplementations(t *testing.T) { var _ driver.RowsColumnTypeScanType = (*awsDriver.AwsWrapperMySQLRows)(nil) var _ driver.RowsColumnTypeNullable = (*awsDriver.AwsWrapperMySQLRows)(nil) var _ error = (*error_util.AwsWrapperError)(nil) + var _ driver_infrastructure.ExecuteRouting = (*bg.SuspendExecuteRouting)(nil) + var _ driver_infrastructure.ConnectRouting = (*bg.SuspendConnectRouting)(nil) + var _ driver_infrastructure.ConnectRouting = (*bg.SubstituteConnectRouting)(nil) + var _ driver_infrastructure.ConnectRouting = (*bg.RejectConnectRouting)(nil) + var _ driver_infrastructure.ConnectRouting = (*bg.SuspendUntilCorrespondingHostFoundConnectRouting)(nil) } diff --git a/.test/test/limitless_query_helper_test.go b/.test/test/limitless_query_helper_test.go index 17d49314..a45a6b86 100644 --- a/.test/test/limitless_query_helper_test.go +++ b/.test/test/limitless_query_helper_test.go @@ -23,6 +23,7 @@ import ( mock_driver_infrastructure "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/awssql/driver_infrastructure" mock_database_sql_driver "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/database_sql_driver" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/limitless" "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" "github.com/golang/mock/gomock" @@ -44,16 +45,15 @@ func TestQueryForLimitlessRouters_Success(t *testing.T) { MockQueryerContext: mock_database_sql_driver.NewMockQueryerContext(ctrl), } mockPlugin := mock_driver_infrastructure.NewMockPluginService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockAuroraLimitlessDialect(ctrl) mockRows := mock_database_sql_driver.NewMockRows(ctrl) + dialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + query := dialect.GetLimitlessRouterEndpointQuery() props := map[string]string{ property_util.LIMITLESS_ROUTER_QUERY_TIMEOUT_MS.Name: "100", } - query := "SELECT * FROM limitless_router" - mockPlugin.EXPECT().GetDialect().Return(mockDialect) - mockDialect.EXPECT().GetLimitlessRouterEndpointQuery().Return(query) + mockPlugin.EXPECT().GetDialect().Return(dialect) mockConn.MockQueryerContext.EXPECT(). QueryContext(gomock.Any(), query, gomock.Nil()). @@ -84,10 +84,9 @@ func TestQueryForLimitlessRouters_ConnDoesNotImplementQueryerContext(t *testing. mockConn := mock_database_sql_driver.NewMockConn(ctrl) mockPlugin := mock_driver_infrastructure.NewMockPluginService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockAuroraLimitlessDialect(ctrl) + dialect := &driver_infrastructure.AuroraPgDatabaseDialect{} - mockPlugin.EXPECT().GetDialect().Return(mockDialect) - mockDialect.EXPECT().GetLimitlessRouterEndpointQuery().Return("SELECT") + mockPlugin.EXPECT().GetDialect().Return(dialect) props := map[string]string{ property_util.LIMITLESS_ROUTER_QUERY_TIMEOUT_MS.Name: "100", @@ -109,10 +108,9 @@ func TestQueryForLimitlessRouters_InvalidTimeoutProperty(t *testing.T) { MockQueryerContext: mock_database_sql_driver.NewMockQueryerContext(ctrl), } mockPlugin := mock_driver_infrastructure.NewMockPluginService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockAuroraLimitlessDialect(ctrl) + dialect := &driver_infrastructure.AuroraPgDatabaseDialect{} - mockPlugin.EXPECT().GetDialect().Return(mockDialect) - mockDialect.EXPECT().GetLimitlessRouterEndpointQuery().Return("SELECT") + mockPlugin.EXPECT().GetDialect().Return(dialect) mockConn.MockQueryerContext. EXPECT(). QueryContext(gomock.Any(), gomock.Any(), gomock.Any()). @@ -137,13 +135,12 @@ func TestQueryForLimitlessRouters_QueryFails(t *testing.T) { MockQueryerContext: mock_database_sql_driver.NewMockQueryerContext(ctrl), } mockPlugin := mock_driver_infrastructure.NewMockPluginService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockAuroraLimitlessDialect(ctrl) + dialect := &driver_infrastructure.AuroraPgDatabaseDialect{} - mockPlugin.EXPECT().GetDialect().Return(mockDialect) - mockDialect.EXPECT().GetLimitlessRouterEndpointQuery().Return("SELECT") + mockPlugin.EXPECT().GetDialect().Return(dialect) mockConn.MockQueryerContext.EXPECT(). - QueryContext(gomock.Any(), "SELECT", gomock.Nil()). + QueryContext(gomock.Any(), gomock.Any(), gomock.Nil()). Return(nil, errors.New("query error")) props := map[string]string{ @@ -166,13 +163,12 @@ func TestQueryForLimitlessRouters_EmptyHostNameOrBadType(t *testing.T) { MockQueryerContext: mock_database_sql_driver.NewMockQueryerContext(ctrl), } mockPlugin := mock_driver_infrastructure.NewMockPluginService(ctrl) - mockDialect := mock_driver_infrastructure.NewMockAuroraLimitlessDialect(ctrl) + dialect := &driver_infrastructure.AuroraPgDatabaseDialect{} mockRows := mock_database_sql_driver.NewMockRows(ctrl) - mockPlugin.EXPECT().GetDialect().Return(mockDialect) - mockDialect.EXPECT().GetLimitlessRouterEndpointQuery().Return("SELECT") + mockPlugin.EXPECT().GetDialect().Return(dialect) mockConn.MockQueryerContext.EXPECT(). - QueryContext(gomock.Any(), "SELECT", gomock.Nil()). + QueryContext(gomock.Any(), gomock.Any(), gomock.Nil()). Return(mockRows, nil) mockRows.EXPECT().Columns().Return([]string{"host", "load"}) diff --git a/.test/test/mock_implementations.go b/.test/test/mock_implementations.go index fdadb6a4..fab03f5a 100644 --- a/.test/test/mock_implementations.go +++ b/.test/test/mock_implementations.go @@ -51,6 +51,8 @@ var defaultPluginFactoryByCode = map[string]driver_infrastructure.ConnectionPlug "executionTime": plugins.NewExecutionTimePluginFactory(), } +var testPluginCode string = "test" + type TestPlugin struct { calls *[]string id int @@ -59,6 +61,10 @@ type TestPlugin struct { isBefore bool } +func (t TestPlugin) GetPluginCode() string { + return testPluginCode +} + func (t TestPlugin) GetSubscribedMethods() []string { switch t.id { case 1: @@ -147,7 +153,6 @@ func (t TestPlugin) NotifyHostListChanged(changes map[string]map[driver_infrastr } func (t TestPlugin) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { @@ -186,7 +191,7 @@ func CreateTestPlugin(calls *[]string, id int, connection driver.Conn, err error if calls == nil { calls = &[]string{} } - testPlugin := driver_infrastructure.ConnectionPlugin(&TestPlugin{calls: calls, id: id, connection: connection, error: err, isBefore: isBefore}) + testPlugin := &TestPlugin{calls: calls, id: id, connection: connection, error: err, isBefore: isBefore} return testPlugin } @@ -424,7 +429,7 @@ func (m *MockPluginService) GetCurrentTx() driver.Tx { func (m *MockPluginService) SetCurrentTx(tx driver.Tx) {} -func (m *MockPluginService) CreateHostListProvider(props map[string]string, dsn string) driver_infrastructure.HostListProvider { +func (m *MockPluginService) CreateHostListProvider(props map[string]string) driver_infrastructure.HostListProvider { return nil } @@ -515,6 +520,17 @@ func (m *MockPluginService) IsReadOnly() bool { return false } +func (p *MockPluginService) GetBgStatus(id string) (driver_infrastructure.BlueGreenStatus, bool) { + return driver_infrastructure.BlueGreenStatus{}, true +} + +func (p *MockPluginService) SetBgStatus(status driver_infrastructure.BlueGreenStatus, id string) { +} + +func (p *MockPluginService) IsPluginInUse(pluginCode string) bool { + return false +} + type MockDriverConn struct { driver.Conn } @@ -558,7 +574,7 @@ func (m *MockRdsHostListProviderService) IsStaticHostListProvider() bool { return false } -func (m *MockRdsHostListProviderService) CreateHostListProvider(props map[string]string, dsn string) driver_infrastructure.HostListProvider { +func (m *MockRdsHostListProviderService) CreateHostListProvider(props map[string]string) driver_infrastructure.HostListProvider { return nil } @@ -654,7 +670,7 @@ func (m MockHttpClient) Do(req *http.Request) (*http.Response, error) { resp := m.doReturnValues[idx] - (*m.doCallCount)++ + *m.doCallCount++ return resp, m.errReturnValue } diff --git a/.test/test/mocks/awssql/driver_infrastructure/mock_connection_plugin.go b/.test/test/mocks/awssql/driver_infrastructure/mock_connection_plugin.go index f5d81819..43a01b51 100644 --- a/.test/test/mocks/awssql/driver_infrastructure/mock_connection_plugin.go +++ b/.test/test/mocks/awssql/driver_infrastructure/mock_connection_plugin.go @@ -148,6 +148,20 @@ func (mr *MockConnectionPluginMockRecorder) GetHostSelectorStrategy(strategy int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostSelectorStrategy", reflect.TypeOf((*MockConnectionPlugin)(nil).GetHostSelectorStrategy), strategy) } +// GetPluginCode mocks base method. +func (m *MockConnectionPlugin) GetPluginCode() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPluginCode") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetPluginCode indicates an expected call of GetPluginCode. +func (mr *MockConnectionPluginMockRecorder) GetPluginCode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPluginCode", reflect.TypeOf((*MockConnectionPlugin)(nil).GetPluginCode)) +} + // GetSubscribedMethods mocks base method. func (m *MockConnectionPlugin) GetSubscribedMethods() []string { m.ctrl.T.Helper() @@ -163,17 +177,17 @@ func (mr *MockConnectionPluginMockRecorder) GetSubscribedMethods() *gomock.Call } // InitHostProvider mocks base method. -func (m *MockConnectionPlugin) InitHostProvider(initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { +func (m *MockConnectionPlugin) InitHostProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InitHostProvider", initialUrl, props, hostListProviderService, initHostProviderFunc) + ret := m.ctrl.Call(m, "InitHostProvider", props, hostListProviderService, initHostProviderFunc) ret0, _ := ret[0].(error) return ret0 } // InitHostProvider indicates an expected call of InitHostProvider. -func (mr *MockConnectionPluginMockRecorder) InitHostProvider(initialUrl, props, hostListProviderService, initHostProviderFunc interface{}) *gomock.Call { +func (mr *MockConnectionPluginMockRecorder) InitHostProvider(props, hostListProviderService, initHostProviderFunc interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockConnectionPlugin)(nil).InitHostProvider), initialUrl, props, hostListProviderService, initHostProviderFunc) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockConnectionPlugin)(nil).InitHostProvider), props, hostListProviderService, initHostProviderFunc) } // NotifyConnectionChanged mocks base method. diff --git a/.test/test/mocks/awssql/driver_infrastructure/mock_database_dialect.go b/.test/test/mocks/awssql/driver_infrastructure/mock_database_dialect.go index 95c228de..06ddbb45 100644 --- a/.test/test/mocks/awssql/driver_infrastructure/mock_database_dialect.go +++ b/.test/test/mocks/awssql/driver_infrastructure/mock_database_dialect.go @@ -170,17 +170,17 @@ func (mr *MockDatabaseDialectMockRecorder) GetHostAliasQuery() *gomock.Call { } // GetHostListProvider mocks base method. -func (m *MockDatabaseDialect) GetHostListProvider(props map[string]string, initialDsn string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { +func (m *MockDatabaseDialect) GetHostListProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostListProvider", props, initialDsn, hostListProviderService, pluginService) + ret := m.ctrl.Call(m, "GetHostListProvider", props, hostListProviderService, pluginService) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // GetHostListProvider indicates an expected call of GetHostListProvider. -func (mr *MockDatabaseDialectMockRecorder) GetHostListProvider(props, initialDsn, hostListProviderService, pluginService interface{}) *gomock.Call { +func (mr *MockDatabaseDialectMockRecorder) GetHostListProvider(props, hostListProviderService, pluginService interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockDatabaseDialect)(nil).GetHostListProvider), props, initialDsn, hostListProviderService, pluginService) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockDatabaseDialect)(nil).GetHostListProvider), props, hostListProviderService, pluginService) } // GetServerVersionQuery mocks base method. @@ -427,17 +427,17 @@ func (mr *MockTopologyAwareDialectMockRecorder) GetHostAliasQuery() *gomock.Call } // GetHostListProvider mocks base method. -func (m *MockTopologyAwareDialect) GetHostListProvider(props map[string]string, initialDsn string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { +func (m *MockTopologyAwareDialect) GetHostListProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostListProvider", props, initialDsn, hostListProviderService, pluginService) + ret := m.ctrl.Call(m, "GetHostListProvider", props, hostListProviderService, pluginService) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // GetHostListProvider indicates an expected call of GetHostListProvider. -func (mr *MockTopologyAwareDialectMockRecorder) GetHostListProvider(props, initialDsn, hostListProviderService, pluginService interface{}) *gomock.Call { +func (mr *MockTopologyAwareDialectMockRecorder) GetHostListProvider(props, hostListProviderService, pluginService interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockTopologyAwareDialect)(nil).GetHostListProvider), props, initialDsn, hostListProviderService, pluginService) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockTopologyAwareDialect)(nil).GetHostListProvider), props, hostListProviderService, pluginService) } // GetHostName mocks base method. @@ -742,17 +742,17 @@ func (mr *MockAuroraLimitlessDialectMockRecorder) GetHostAliasQuery() *gomock.Ca } // GetHostListProvider mocks base method. -func (m *MockAuroraLimitlessDialect) GetHostListProvider(props map[string]string, initialDsn string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { +func (m *MockAuroraLimitlessDialect) GetHostListProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostListProvider", props, initialDsn, hostListProviderService, pluginService) + ret := m.ctrl.Call(m, "GetHostListProvider", props, hostListProviderService, pluginService) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // GetHostListProvider indicates an expected call of GetHostListProvider. -func (mr *MockAuroraLimitlessDialectMockRecorder) GetHostListProvider(props, initialDsn, hostListProviderService, pluginService interface{}) *gomock.Call { +func (mr *MockAuroraLimitlessDialectMockRecorder) GetHostListProvider(props, hostListProviderService, pluginService interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockAuroraLimitlessDialect)(nil).GetHostListProvider), props, initialDsn, hostListProviderService, pluginService) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockAuroraLimitlessDialect)(nil).GetHostListProvider), props, hostListProviderService, pluginService) } // GetLimitlessRouterEndpointQuery mocks base method. @@ -871,3 +871,288 @@ func (mr *MockAuroraLimitlessDialectMockRecorder) IsDialect(conn interface{}) *g mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDialect", reflect.TypeOf((*MockAuroraLimitlessDialect)(nil).IsDialect), conn) } + +// MockBlueGreenDialect is a mock of BlueGreenDialect interface. +type MockBlueGreenDialect struct { + ctrl *gomock.Controller + recorder *MockBlueGreenDialectMockRecorder +} + +// MockBlueGreenDialectMockRecorder is the mock recorder for MockBlueGreenDialect. +type MockBlueGreenDialectMockRecorder struct { + mock *MockBlueGreenDialect +} + +// NewMockBlueGreenDialect creates a new mock instance. +func NewMockBlueGreenDialect(ctrl *gomock.Controller) *MockBlueGreenDialect { + mock := &MockBlueGreenDialect{ctrl: ctrl} + mock.recorder = &MockBlueGreenDialectMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBlueGreenDialect) EXPECT() *MockBlueGreenDialectMockRecorder { + return m.recorder +} + +// DoesStatementSetAutoCommit mocks base method. +func (m *MockBlueGreenDialect) DoesStatementSetAutoCommit(statement string) (bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesStatementSetAutoCommit", statement) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// DoesStatementSetAutoCommit indicates an expected call of DoesStatementSetAutoCommit. +func (mr *MockBlueGreenDialectMockRecorder) DoesStatementSetAutoCommit(statement interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesStatementSetAutoCommit", reflect.TypeOf((*MockBlueGreenDialect)(nil).DoesStatementSetAutoCommit), statement) +} + +// DoesStatementSetCatalog mocks base method. +func (m *MockBlueGreenDialect) DoesStatementSetCatalog(statement string) (string, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesStatementSetCatalog", statement) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// DoesStatementSetCatalog indicates an expected call of DoesStatementSetCatalog. +func (mr *MockBlueGreenDialectMockRecorder) DoesStatementSetCatalog(statement interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesStatementSetCatalog", reflect.TypeOf((*MockBlueGreenDialect)(nil).DoesStatementSetCatalog), statement) +} + +// DoesStatementSetReadOnly mocks base method. +func (m *MockBlueGreenDialect) DoesStatementSetReadOnly(statement string) (bool, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesStatementSetReadOnly", statement) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// DoesStatementSetReadOnly indicates an expected call of DoesStatementSetReadOnly. +func (mr *MockBlueGreenDialectMockRecorder) DoesStatementSetReadOnly(statement interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesStatementSetReadOnly", reflect.TypeOf((*MockBlueGreenDialect)(nil).DoesStatementSetReadOnly), statement) +} + +// DoesStatementSetSchema mocks base method. +func (m *MockBlueGreenDialect) DoesStatementSetSchema(statement string) (string, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesStatementSetSchema", statement) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// DoesStatementSetSchema indicates an expected call of DoesStatementSetSchema. +func (mr *MockBlueGreenDialectMockRecorder) DoesStatementSetSchema(statement interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesStatementSetSchema", reflect.TypeOf((*MockBlueGreenDialect)(nil).DoesStatementSetSchema), statement) +} + +// DoesStatementSetTransactionIsolation mocks base method. +func (m *MockBlueGreenDialect) DoesStatementSetTransactionIsolation(statement string) (driver_infrastructure.TransactionIsolationLevel, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoesStatementSetTransactionIsolation", statement) + ret0, _ := ret[0].(driver_infrastructure.TransactionIsolationLevel) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// DoesStatementSetTransactionIsolation indicates an expected call of DoesStatementSetTransactionIsolation. +func (mr *MockBlueGreenDialectMockRecorder) DoesStatementSetTransactionIsolation(statement interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoesStatementSetTransactionIsolation", reflect.TypeOf((*MockBlueGreenDialect)(nil).DoesStatementSetTransactionIsolation), statement) +} + +// GetBlueGreenStatus mocks base method. +func (m *MockBlueGreenDialect) GetBlueGreenStatus(conn driver.Conn) []driver_infrastructure.BlueGreenResult { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBlueGreenStatus", conn) + ret0, _ := ret[0].([]driver_infrastructure.BlueGreenResult) + return ret0 +} + +// GetBlueGreenStatus indicates an expected call of GetBlueGreenStatus. +func (mr *MockBlueGreenDialectMockRecorder) GetBlueGreenStatus(conn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlueGreenStatus", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetBlueGreenStatus), conn) +} + +// GetDefaultPort mocks base method. +func (m *MockBlueGreenDialect) GetDefaultPort() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDefaultPort") + ret0, _ := ret[0].(int) + return ret0 +} + +// GetDefaultPort indicates an expected call of GetDefaultPort. +func (mr *MockBlueGreenDialectMockRecorder) GetDefaultPort() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultPort", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetDefaultPort)) +} + +// GetDialectUpdateCandidates mocks base method. +func (m *MockBlueGreenDialect) GetDialectUpdateCandidates() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDialectUpdateCandidates") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetDialectUpdateCandidates indicates an expected call of GetDialectUpdateCandidates. +func (mr *MockBlueGreenDialectMockRecorder) GetDialectUpdateCandidates() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDialectUpdateCandidates", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetDialectUpdateCandidates)) +} + +// GetHostAliasQuery mocks base method. +func (m *MockBlueGreenDialect) GetHostAliasQuery() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHostAliasQuery") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetHostAliasQuery indicates an expected call of GetHostAliasQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetHostAliasQuery() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostAliasQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetHostAliasQuery)) +} + +// GetHostListProvider mocks base method. +func (m *MockBlueGreenDialect) GetHostListProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, pluginService driver_infrastructure.PluginService) driver_infrastructure.HostListProvider { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHostListProvider", props, hostListProviderService, pluginService) + ret0, _ := ret[0].(driver_infrastructure.HostListProvider) + return ret0 +} + +// GetHostListProvider indicates an expected call of GetHostListProvider. +func (mr *MockBlueGreenDialectMockRecorder) GetHostListProvider(props, hostListProviderService, pluginService interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostListProvider", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetHostListProvider), props, hostListProviderService, pluginService) +} + +// GetServerVersionQuery mocks base method. +func (m *MockBlueGreenDialect) GetServerVersionQuery() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServerVersionQuery") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetServerVersionQuery indicates an expected call of GetServerVersionQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetServerVersionQuery() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerVersionQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetServerVersionQuery)) +} + +// GetSetAutoCommitQuery mocks base method. +func (m *MockBlueGreenDialect) GetSetAutoCommitQuery(autoCommit bool) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetAutoCommitQuery", autoCommit) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetAutoCommitQuery indicates an expected call of GetSetAutoCommitQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetSetAutoCommitQuery(autoCommit interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetAutoCommitQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetSetAutoCommitQuery), autoCommit) +} + +// GetSetCatalogQuery mocks base method. +func (m *MockBlueGreenDialect) GetSetCatalogQuery(catalog string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetCatalogQuery", catalog) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetCatalogQuery indicates an expected call of GetSetCatalogQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetSetCatalogQuery(catalog interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetCatalogQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetSetCatalogQuery), catalog) +} + +// GetSetReadOnlyQuery mocks base method. +func (m *MockBlueGreenDialect) GetSetReadOnlyQuery(readOnly bool) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetReadOnlyQuery", readOnly) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetReadOnlyQuery indicates an expected call of GetSetReadOnlyQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetSetReadOnlyQuery(readOnly interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetReadOnlyQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetSetReadOnlyQuery), readOnly) +} + +// GetSetSchemaQuery mocks base method. +func (m *MockBlueGreenDialect) GetSetSchemaQuery(schema string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetSchemaQuery", schema) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetSchemaQuery indicates an expected call of GetSetSchemaQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetSetSchemaQuery(schema interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetSchemaQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetSetSchemaQuery), schema) +} + +// GetSetTransactionIsolationQuery mocks base method. +func (m *MockBlueGreenDialect) GetSetTransactionIsolationQuery(level driver_infrastructure.TransactionIsolationLevel) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSetTransactionIsolationQuery", level) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSetTransactionIsolationQuery indicates an expected call of GetSetTransactionIsolationQuery. +func (mr *MockBlueGreenDialectMockRecorder) GetSetTransactionIsolationQuery(level interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSetTransactionIsolationQuery", reflect.TypeOf((*MockBlueGreenDialect)(nil).GetSetTransactionIsolationQuery), level) +} + +// IsBlueGreenStatusAvailable mocks base method. +func (m *MockBlueGreenDialect) IsBlueGreenStatusAvailable(conn driver.Conn) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsBlueGreenStatusAvailable", conn) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsBlueGreenStatusAvailable indicates an expected call of IsBlueGreenStatusAvailable. +func (mr *MockBlueGreenDialectMockRecorder) IsBlueGreenStatusAvailable(conn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBlueGreenStatusAvailable", reflect.TypeOf((*MockBlueGreenDialect)(nil).IsBlueGreenStatusAvailable), conn) +} + +// IsDialect mocks base method. +func (m *MockBlueGreenDialect) IsDialect(conn driver.Conn) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDialect", conn) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsDialect indicates an expected call of IsDialect. +func (mr *MockBlueGreenDialectMockRecorder) IsDialect(conn interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDialect", reflect.TypeOf((*MockBlueGreenDialect)(nil).IsDialect), conn) +} diff --git a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go index 6039bd45..8db2b1f9 100644 --- a/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go +++ b/.test/test/mocks/awssql/driver_infrastructure/mock_plugin_helpers.go @@ -55,17 +55,17 @@ func (m *MockHostListProviderService) EXPECT() *MockHostListProviderServiceMockR } // CreateHostListProvider mocks base method. -func (m *MockHostListProviderService) CreateHostListProvider(props map[string]string, dsn string) driver_infrastructure.HostListProvider { +func (m *MockHostListProviderService) CreateHostListProvider(props map[string]string) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateHostListProvider", props, dsn) + ret := m.ctrl.Call(m, "CreateHostListProvider", props) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // CreateHostListProvider indicates an expected call of CreateHostListProvider. -func (mr *MockHostListProviderServiceMockRecorder) CreateHostListProvider(props, dsn interface{}) *gomock.Call { +func (mr *MockHostListProviderServiceMockRecorder) CreateHostListProvider(props interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).CreateHostListProvider), props, dsn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockHostListProviderService)(nil).CreateHostListProvider), props) } // GetCurrentConnection mocks base method. @@ -201,17 +201,17 @@ func (mr *MockPluginServiceMockRecorder) Connect(hostInfo, props, pluginToSkip i } // CreateHostListProvider mocks base method. -func (m *MockPluginService) CreateHostListProvider(props map[string]string, dsn string) driver_infrastructure.HostListProvider { +func (m *MockPluginService) CreateHostListProvider(props map[string]string) driver_infrastructure.HostListProvider { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateHostListProvider", props, dsn) + ret := m.ctrl.Call(m, "CreateHostListProvider", props) ret0, _ := ret[0].(driver_infrastructure.HostListProvider) return ret0 } // CreateHostListProvider indicates an expected call of CreateHostListProvider. -func (mr *MockPluginServiceMockRecorder) CreateHostListProvider(props, dsn interface{}) *gomock.Call { +func (mr *MockPluginServiceMockRecorder) CreateHostListProvider(props interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockPluginService)(nil).CreateHostListProvider), props, dsn) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHostListProvider", reflect.TypeOf((*MockPluginService)(nil).CreateHostListProvider), props) } // FillAliases mocks base method. @@ -455,6 +455,21 @@ func (mr *MockPluginServiceMockRecorder) GetProperties() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProperties", reflect.TypeOf((*MockPluginService)(nil).GetProperties)) } +// GetStatus mocks base method. +func (m *MockPluginService) GetBgStatus(id string) (driver_infrastructure.BlueGreenStatus, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBgStatus", id) + ret0, _ := ret[0].(driver_infrastructure.BlueGreenStatus) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetStatus indicates an expected call of GetStatus. +func (mr *MockPluginServiceMockRecorder) GetStatus(id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBgStatus", reflect.TypeOf((*MockPluginService)(nil).GetBgStatus), id) +} + // GetTargetDriverDialect mocks base method. func (m *MockPluginService) GetTargetDriverDialect() driver_infrastructure.DriverDialect { m.ctrl.T.Helper() @@ -569,6 +584,20 @@ func (mr *MockPluginServiceMockRecorder) IsNetworkError(err interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsNetworkError", reflect.TypeOf((*MockPluginService)(nil).IsNetworkError), err) } +// IsPluginInUse mocks base method. +func (m *MockPluginService) IsPluginInUse(pluginName string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPluginInUse", pluginName) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPluginInUse indicates an expected call of IsPluginInUse. +func (mr *MockPluginServiceMockRecorder) IsPluginInUse(pluginName interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginService)(nil).IsPluginInUse), pluginName) +} + // IsReadOnly mocks base method. func (m *MockPluginService) IsReadOnly() bool { m.ctrl.T.Helper() @@ -697,6 +726,18 @@ func (mr *MockPluginServiceMockRecorder) SetInitialConnectionHostInfo(info inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetInitialConnectionHostInfo", reflect.TypeOf((*MockPluginService)(nil).SetInitialConnectionHostInfo), info) } +// SetStatus mocks base method. +func (m *MockPluginService) SetBgStatus(status driver_infrastructure.BlueGreenStatus, id string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetBgStatus", status, id) +} + +// SetStatus indicates an expected call of SetStatus. +func (mr *MockPluginServiceMockRecorder) SetStatus(status, id interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBgStatus", reflect.TypeOf((*MockPluginService)(nil).SetBgStatus), status, id) +} + // SetTelemetryContext mocks base method. func (m *MockPluginService) SetTelemetryContext(ctx context.Context) { m.ctrl.T.Helper() @@ -942,17 +983,31 @@ func (mr *MockPluginManagerMockRecorder) Init(pluginService, plugins interface{} } // InitHostProvider mocks base method. -func (m *MockPluginManager) InitHostProvider(initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService) error { +func (m *MockPluginManager) InitHostProvider(props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InitHostProvider", initialUrl, props, hostListProviderService) + ret := m.ctrl.Call(m, "InitHostProvider", props, hostListProviderService) ret0, _ := ret[0].(error) return ret0 } // InitHostProvider indicates an expected call of InitHostProvider. -func (mr *MockPluginManagerMockRecorder) InitHostProvider(initialUrl, props, hostListProviderService interface{}) *gomock.Call { +func (mr *MockPluginManagerMockRecorder) InitHostProvider(props, hostListProviderService interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockPluginManager)(nil).InitHostProvider), props, hostListProviderService) +} + +// IsPluginInUse mocks base method. +func (m *MockPluginManager) IsPluginInUse(pluginName string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPluginInUse", pluginName) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPluginInUse indicates an expected call of IsPluginInUse. +func (mr *MockPluginManagerMockRecorder) IsPluginInUse(pluginName interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitHostProvider", reflect.TypeOf((*MockPluginManager)(nil).InitHostProvider), initialUrl, props, hostListProviderService) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPluginInUse", reflect.TypeOf((*MockPluginManager)(nil).IsPluginInUse), pluginName) } // NotifyConnectionChanged mocks base method. @@ -1019,6 +1074,20 @@ func (mr *MockPluginManagerMockRecorder) SetTelemetryContext(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTelemetryContext", reflect.TypeOf((*MockPluginManager)(nil).SetTelemetryContext), ctx) } +// UnwrapPlugin mocks base method. +func (m *MockPluginManager) UnwrapPlugin(pluginCode string) driver_infrastructure.ConnectionPlugin { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnwrapPlugin", pluginCode) + ret0, _ := ret[0].(driver_infrastructure.ConnectionPlugin) + return ret0 +} + +// UnwrapPlugin indicates an expected call of UnwrapPlugin. +func (mr *MockPluginManagerMockRecorder) UnwrapPlugin(pluginCode interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnwrapPlugin", reflect.TypeOf((*MockPluginManager)(nil).UnwrapPlugin), pluginCode) +} + // MockCanReleaseResources is a mock of CanReleaseResources interface. type MockCanReleaseResources struct { ctrl *gomock.Controller diff --git a/.test/test/mysql_database_dialects_test.go b/.test/test/mysql_database_dialects_test.go index ed4d946e..c04c08d5 100644 --- a/.test/test/mysql_database_dialects_test.go +++ b/.test/test/mysql_database_dialects_test.go @@ -113,7 +113,6 @@ func TestMySQLDatabaseDialect_GetHostListProvider(t *testing.T) { testDatabaseDialect := &driver_infrastructure.MySQLDatabaseDialect{} hostListProvider := testDatabaseDialect.GetHostListProvider( make(map[string]string), - "dsn", nil, nil) @@ -202,7 +201,6 @@ func TestRdsMySQLDatabaseDialect_GetHostListProvider(t *testing.T) { testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} hostListProvider := testDatabaseDialect.GetHostListProvider( make(map[string]string), - "dsn", nil, nil) @@ -294,7 +292,6 @@ func TestAuroraRdsMySQLDatabaseDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsNoFailover, "efm") hostListProvider := testDatabaseDialect.GetHostListProvider( propsNoFailover, - "dsn", nil, nil) @@ -306,7 +303,6 @@ func TestAuroraRdsMySQLDatabaseDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsWithFailover, "failover") hostListProvider = testDatabaseDialect.GetHostListProvider( propsWithFailover, - "dsn", nil, nil) @@ -552,7 +548,7 @@ func TestAuroraRdsMySQLDatabaseDialect_GetWriterHostName(t *testing.T) { func TestRdsMultiAzClusterMySQLDatabaseDialect_GetDialectUpdateCandidates(t *testing.T) { testDatabaseDialect := &driver_infrastructure.RdsMultiAzClusterMySQLDatabaseDialect{} - expectedCandidates := []string{} + var expectedCandidates []string assert.ElementsMatch(t, expectedCandidates, testDatabaseDialect.GetDialectUpdateCandidates()) } @@ -585,7 +581,6 @@ func TestRdsMultiAzClusterMySQLDatabaseDialect_GetHostListProvider(t *testing.T) property_util.PLUGINS.Set(propsNoFailover, "efm") hostListProvider := testDatabaseDialect.GetHostListProvider( propsNoFailover, - "dsn", nil, nil) @@ -597,7 +592,6 @@ func TestRdsMultiAzClusterMySQLDatabaseDialect_GetHostListProvider(t *testing.T) property_util.PLUGINS.Set(propsWithFailover, "failover") hostListProvider = testDatabaseDialect.GetHostListProvider( propsWithFailover, - "dsn", nil, nil) @@ -739,7 +733,7 @@ func TestRdsMultiAzClusterMySQLDatabaseDialect_GetWriterHostName(t *testing.T) { mockRows.EXPECT().Columns().Return([]string{"Something", "Source_Server_Id"}) mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { dest[0] = int64(123) - dest[1] = int64(hostId) + dest[1] = hostId return nil }) mockRows.EXPECT().Close().Return(nil) @@ -767,7 +761,7 @@ func TestRdsMultiAzClusterMySQLDatabaseDialect_GetWriterHostName(t *testing.T) { mockRows.EXPECT().Columns().Return([]string{"Something", "Source_Server_Id"}) mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { dest[0] = int64(123) - dest[1] = int64(hostId) + dest[1] = hostId return nil }) mockRows.EXPECT().Close().Return(nil) @@ -1062,3 +1056,385 @@ func TestMysqlGetSetTransactionIsolationQuery(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "set session transaction isolation level SERIALIZABLE", query) } + +func TestAuroraMySQLDatabaseDialect_GetBlueGreenStatus(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"version", "endpoint", "port", "role", "status"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = []uint8("myapp-prod-db.c1a2b3c4d5e6.us-east-1.rds.amazonaws.com") + dest[2] = int64(3306) + dest[3] = []uint8("BLUE_GREEN_DEPLOYMENT_SOURCE") + dest[4] = []uint8("AVAILABLE") + dest[0] = []uint8("1.0") + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = []uint8("myapp-prod-db-target.c1a2b3c4d5e6.us-east-1.rds.amazonaws.com") + dest[2] = int64(3306) + dest[3] = []uint8("BLUE_GREEN_DEPLOYMENT_TARGET") + dest[4] = []uint8("SWITCHOVER_INITIATED") + dest[0] = []uint8("1.1") + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Len(t, results, 2) + + assert.Equal(t, "1.0", results[0].Version) + assert.Equal(t, "myapp-prod-db.c1a2b3c4d5e6.us-east-1.rds.amazonaws.com", results[0].Endpoint) + assert.Equal(t, 3306, results[0].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_SOURCE", results[0].Role) + assert.Equal(t, "AVAILABLE", results[0].Status) + + assert.Equal(t, "1.1", results[1].Version) + assert.Equal(t, "myapp-prod-db-target.c1a2b3c4d5e6.us-east-1.rds.amazonaws.com", results[1].Endpoint) + assert.Equal(t, 3306, results[1].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_TARGET", results[1].Role) + assert.Equal(t, "SWITCHOVER_INITIATED", results[1].Status) +} + +func TestAuroraMySQLDatabaseDialect_GetBlueGreenStatus_QueryError(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(nil, fmt.Errorf("table does not exist")) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + assert.Nil(t, results) +} + +func TestAuroraMySQLDatabaseDialect_GetBlueGreenStatus_NoQueryerContext(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + + results := testDatabaseDialect.GetBlueGreenStatus(mockConn) + assert.Nil(t, results) +} + +func TestAuroraMySQLDatabaseDialect_IsBlueGreenStatusAvailable(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"tmp"}) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = int64(1) + return nil + }) + mockRows.EXPECT().Close().Return(nil) + + result := testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.True(t, result) + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{}) + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrBadConn) + mockRows.EXPECT().Close().Return(nil) + + result = testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.False(t, result) +} + +func TestRdsMySQLDatabaseDialect_GetBlueGreenStatus(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"version", "endpoint", "port", "role", "status"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = []uint8("user-service-db.x7y8z9a1b2c3.eu-west-1.rds.amazonaws.com") + dest[2] = int64(3306) + dest[3] = []uint8("BLUE_GREEN_DEPLOYMENT_SOURCE") + dest[4] = []uint8("AVAILABLE") + dest[0] = []uint8("2.0") + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Len(t, results, 1) + assert.Equal(t, "2.0", results[0].Version) + assert.Equal(t, "user-service-db.x7y8z9a1b2c3.eu-west-1.rds.amazonaws.com", results[0].Endpoint) + assert.Equal(t, 3306, results[0].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_SOURCE", results[0].Role) + assert.Equal(t, "AVAILABLE", results[0].Status) +} + +func TestRdsMySQLDatabaseDialect_GetBlueGreenStatus_EmptyResults(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"endpoint", "port", "role", "status", "version"}) + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) // No rows + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + assert.Empty(t, results) +} + +func TestRdsMySQLDatabaseDialect_IsBlueGreenStatusAvailable(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"tmp"}) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = int64(1) + return nil + }) + mockRows.EXPECT().Close().Return(nil) + + result := testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.True(t, result) +} + +func TestRdsMySQLDatabaseDialect_IsBlueGreenStatusAvailable_QueryError(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + + // Test query error - should return false + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(nil, fmt.Errorf("connection error")) + + result := testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.False(t, result) +} + +func TestGetBlueGreenStatus_InvalidRowData(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port", "role", "status", "version"}) + + // Mock row with invalid data types (should be skipped) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = []uint8("1") + dest[1] = "invalid_type" + dest[2] = int64(3306) + dest[3] = []uint8("BLUE_GREEN_DEPLOYMENT_SOURCE") + dest[4] = []uint8("AVAILABLE") + dest[5] = []uint8("1.0") + return nil + }) + + // Mock valid row + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = []uint8("2") + dest[1] = []uint8("valid-endpoint.amazonaws.com") + dest[2] = int64(3306) + dest[3] = []uint8("BLUE_GREEN_DEPLOYMENT_TARGET") + dest[4] = []uint8("AVAILABLE") + dest[5] = []uint8("1.0") + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + // Should only return the valid row, invalid row should be skipped + assert.Len(t, results, 1) + assert.Equal(t, "valid-endpoint.amazonaws.com", results[0].Endpoint) +} + +func TestGetBlueGreenStatus_InsufficientColumns(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsMySQLDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + // Mock only 3 columns instead of required 6 + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = []uint8("1") + dest[1] = []uint8("endpoint.amazonaws.com") + dest[2] = int64(3306) + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Empty(t, results) +} diff --git a/.test/test/pg_database_dialects_test.go b/.test/test/pg_database_dialects_test.go index d9146c42..50229a34 100644 --- a/.test/test/pg_database_dialects_test.go +++ b/.test/test/pg_database_dialects_test.go @@ -111,7 +111,6 @@ func TestPgDatabaseDialect_GetHostListProvider(t *testing.T) { testDatabaseDialect := &driver_infrastructure.PgDatabaseDialect{} hostListProvider := testDatabaseDialect.GetHostListProvider( make(map[string]string), - "dsn", nil, nil) @@ -238,7 +237,6 @@ func TestRdsPgDatabaseDialect_GetHostListProvider(t *testing.T) { testDatabaseDialect := &driver_infrastructure.RdsPgDatabaseDialect{} hostListProvider := testDatabaseDialect.GetHostListProvider( make(map[string]string), - "dsn", nil, nil) @@ -433,7 +431,6 @@ func TestAuroraRdsPgDatabaseDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsNoFailover, "efm") hostListProvider := testDatabaseDialect.GetHostListProvider( propsNoFailover, - "dsn", nil, nil) @@ -445,7 +442,6 @@ func TestAuroraRdsPgDatabaseDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsWithFailover, "failover") hostListProvider = testDatabaseDialect.GetHostListProvider( propsWithFailover, - "dsn", nil, nil) @@ -698,7 +694,7 @@ func TestAuroraRdsPgDatabaseDialect_GetLimitlessRouterEndpointQuery(t *testing.T func TestRdsMultiAzDbClusterPgDialect_GetDialectUpdateCandidates(t *testing.T) { testDatabaseDialect := &driver_infrastructure.RdsMultiAzClusterPgDatabaseDialect{} - expectedCandidates := []string{} + var expectedCandidates []string assert.ElementsMatch(t, expectedCandidates, testDatabaseDialect.GetDialectUpdateCandidates()) } @@ -731,7 +727,6 @@ func TestRdsMultiAzDbClusterPgDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsNoFailover, "efm") hostListProvider := testDatabaseDialect.GetHostListProvider( propsNoFailover, - "dsn", nil, nil) @@ -743,7 +738,6 @@ func TestRdsMultiAzDbClusterPgDialect_GetHostListProvider(t *testing.T) { property_util.PLUGINS.Set(propsWithFailover, "failover") hostListProvider = testDatabaseDialect.GetHostListProvider( propsWithFailover, - "dsn", nil, nil) @@ -1230,3 +1224,365 @@ func TestPgGetSetTransactionIsolationQuery(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "set session characteristics as transaction isolation level SERIALIZABLE", query) } + +func TestAuroraPgDatabaseDialect_GetBlueGreenStatus(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM get_blue_green_fast_switchover_metadata(" + + "'aws_advanced_go_wrapper-" + driver_info.AWS_ADVANCED_GO_WRAPPER_VERSION + "')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port", "role", "status", "version"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = "prod-aurora-cluster.cluster-abc123def456.us-east-1.rds.amazonaws.com" // endpoint + dest[2] = int64(5432) // port + dest[3] = "BLUE_GREEN_DEPLOYMENT_SOURCE" // role + dest[4] = "AVAILABLE" // status + dest[0] = "1.0" // version + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = "prod-aurora-cluster-target.cluster-xyz789def456.us-east-1.rds.amazonaws.com" // endpoint + dest[2] = int64(5432) // port + dest[3] = "BLUE_GREEN_DEPLOYMENT_TARGET" // role + dest[4] = "SWITCHOVER_IN_PROGRESS" // status + dest[0] = "1.1" // version + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Len(t, results, 2) + + assert.Equal(t, "1.0", results[0].Version) + assert.Equal(t, "prod-aurora-cluster.cluster-abc123def456.us-east-1.rds.amazonaws.com", results[0].Endpoint) + assert.Equal(t, 5432, results[0].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_SOURCE", results[0].Role) + assert.Equal(t, "AVAILABLE", results[0].Status) + + assert.Equal(t, "1.1", results[1].Version) + assert.Equal(t, "prod-aurora-cluster-target.cluster-xyz789def456.us-east-1.rds.amazonaws.com", results[1].Endpoint) + assert.Equal(t, 5432, results[1].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_TARGET", results[1].Role) + assert.Equal(t, "SWITCHOVER_IN_PROGRESS", results[1].Status) +} + +func TestAuroraPgDatabaseDialect_GetBlueGreenStatus_QueryError(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM get_blue_green_fast_switchover_metadata(" + + "'aws_advanced_go_wrapper-" + driver_info.AWS_ADVANCED_GO_WRAPPER_VERSION + "')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(nil, fmt.Errorf("function does not exist")) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + assert.Nil(t, results) +} + +func TestAuroraPgDatabaseDialect_GetBlueGreenStatus_NoQueryerContext(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + + results := testDatabaseDialect.GetBlueGreenStatus(mockConn) + assert.Nil(t, results) +} + +func TestAuroraPgDatabaseDialect_IsBlueGreenStatusAvailable(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT 'get_blue_green_fast_switchover_metadata'::regproc" + + // Test when function exists (returns true) + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"regproc"}) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = "get_blue_green_fast_switchover_metadata" + return nil + }) + mockRows.EXPECT().Close().Return(nil) + + result := testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.True(t, result) + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{}) + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrBadConn) + mockRows.EXPECT().Close().Return(nil) + + result = testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.False(t, result) +} + +func TestRdsPgDatabaseDialect_GetBlueGreenStatus(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM rds_tools.show_topology('aws_advanced_go_wrapper-1.0.0')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port", "role", "status", "version"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = "2.0" + dest[1] = "analytics-postgres.s1t2u3v4w5x6.us-west-2.rds.amazonaws.com" + dest[2] = int64(5432) + dest[3] = "BLUE_GREEN_DEPLOYMENT_SOURCE" + dest[4] = "AVAILABLE" + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Len(t, results, 1) + assert.Equal(t, "2.0", results[0].Version) + assert.Equal(t, "analytics-postgres.s1t2u3v4w5x6.us-west-2.rds.amazonaws.com", results[0].Endpoint) + assert.Equal(t, 5432, results[0].Port) + assert.Equal(t, "BLUE_GREEN_DEPLOYMENT_SOURCE", results[0].Role) + assert.Equal(t, "AVAILABLE", results[0].Status) +} + +func TestRdsPgDatabaseDialect_GetBlueGreenStatus_EmptyResults(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM rds_tools.show_topology('aws_advanced_go_wrapper-1.0.0')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port", "role", "status", "version"}) + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) // No rows + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + assert.Empty(t, results) +} + +func TestRdsPgDatabaseDialect_IsBlueGreenStatusAvailable(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT 'rds_tools.show_topology'::regproc" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"regproc"}) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = "rds_tools.show_topology" + return nil + }) + mockRows.EXPECT().Close().Return(nil) + + result := testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.True(t, result) + + // Test query error - should return false + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(nil, fmt.Errorf("connection error")) + + result = testDatabaseDialect.IsBlueGreenStatusAvailable(conn) + assert.False(t, result) +} + +func TestPgGetBlueGreenStatus_InvalidRowData(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.AuroraPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM get_blue_green_fast_switchover_metadata(" + + "'aws_advanced_go_wrapper-" + driver_info.AWS_ADVANCED_GO_WRAPPER_VERSION + "')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port", "role", "status", "version"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = 12345 + dest[2] = int64(5432) + dest[3] = "BLUE_GREEN_DEPLOYMENT_SOURCE" + dest[4] = "AVAILABLE" + dest[0] = "1.0" + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[1] = "valid-endpoint.amazonaws.com" + dest[2] = int64(5432) + dest[3] = "BLUE_GREEN_DEPLOYMENT_TARGET" + dest[4] = "AVAILABLE" + dest[0] = "1.0" + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Len(t, results, 1) + assert.Equal(t, "valid-endpoint.amazonaws.com", results[0].Endpoint) +} + +func TestPgGetBlueGreenStatus_InsufficientColumns(t *testing.T) { + testDatabaseDialect := &driver_infrastructure.RdsPgDatabaseDialect{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + conn := struct { + driver.Conn + driver.QueryerContext + }{ + Conn: mockConn, + QueryerContext: mockQueryer, + } + + expectedQuery := "SELECT version, endpoint, port, role, status FROM rds_tools.show_topology('aws_advanced_go_wrapper-1.0.0')" + + mockQueryer.EXPECT(). + QueryContext(gomock.Any(), expectedQuery, gomock.Nil()). + Return(mockRows, nil) + + // Mock only 3 columns instead of required 6 + mockRows.EXPECT().Columns().Return([]string{"id", "endpoint", "port"}) + + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = "1" + dest[1] = "endpoint.amazonaws.com" + dest[2] = int64(5432) + return nil + }) + + mockRows.EXPECT().Next(gomock.Any()).Return(driver.ErrSkip) + mockRows.EXPECT().Close().Return(nil) + + results := testDatabaseDialect.GetBlueGreenStatus(conn) + + assert.Empty(t, results) +} diff --git a/.test/test/plugin_manager_benchmark_test.go b/.test/test/plugin_manager_benchmark_test.go index f0e77189..228437d7 100644 --- a/.test/test/plugin_manager_benchmark_test.go +++ b/.test/test/plugin_manager_benchmark_test.go @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" ) @@ -107,7 +108,7 @@ func BenchmarkExecute(b *testing.B) { } func BenchmarkInitHostProvider(b *testing.B) { - props := make(map[string]string) + props, _ := utils.ParseDsn(mysqlTestDsn) for _, count := range PLUGIN_COUNTS { count := count // capture range variable @@ -118,7 +119,6 @@ func BenchmarkInitHostProvider(b *testing.B) { for i := 0; i < b.N; i++ { //nolint:errcheck pluginManager.InitHostProvider( - mysqlTestDsn, props, &MockRdsHostListProviderService{}, ) diff --git a/.test/test/plugin_manager_test.go b/.test/test/plugin_manager_test.go index 9a107686..3b63b438 100644 --- a/.test/test/plugin_manager_test.go +++ b/.test/test/plugin_manager_test.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" + "github.com/aws/aws-advanced-go-wrapper/iam" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -608,3 +609,45 @@ func TestConnectPluginToSkip(t *testing.T) { require.NotNil(t, err) assert.Equal(t, error_util.NewGenericAwsWrapperError(error_util.GetMessage("PluginManager.pipelineNone")), err) } + +func TestIsPluginInUse(t *testing.T) { + mockTargetDriver := &MockTargetDriver{} + props := make(map[string]string) + connectionProviderManager := driver_infrastructure.ConnectionProviderManager{} + telemetryFactory, _ := telemetry.NewDefaultTelemetryFactory(props) + pluginManager := plugin_helpers.NewPluginManagerImpl(mockTargetDriver, props, connectionProviderManager, telemetryFactory) + pluginService := driver_infrastructure.PluginService(&plugin_helpers.PluginServiceImpl{}) + + assert.False(t, pluginManager.IsPluginInUse(testPluginCode), "Should return false when no plugins are loaded") + assert.False(t, pluginManager.IsPluginInUse("nonexistentPlugin"), "Should return false for nonexistent plugin") + + var calls []string + plugins := []driver_infrastructure.ConnectionPlugin{ + CreateTestPlugin(&calls, 1, nil, nil, false), + CreateTestPlugin(&calls, 2, nil, nil, false), + CreateTestPlugin(&calls, 3, nil, nil, false), + &iam.IamAuthPlugin{}, + } + + err := pluginManager.Init(pluginService, plugins) + require.Nil(t, err) + + assert.True(t, pluginManager.IsPluginInUse(testPluginCode), "Should return true when TestPlugin is loaded") + assert.True(t, pluginManager.IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE), "Should return true when iam plugin is loaded") + + assert.False(t, pluginManager.IsPluginInUse("nonexistentPlugin"), "Should return false for non-existent plugin type") + assert.False(t, pluginManager.IsPluginInUse("default"), "Should return false for DefaultPlugin when not loaded") + assert.False(t, pluginManager.IsPluginInUse(driver_infrastructure.FAILOVER_PLUGIN_CODE), "Should return false for FailoverPlugin when not loaded") + + assert.False(t, pluginManager.IsPluginInUse("Test"), "Should return false for case-sensitive mismatch") + assert.False(t, pluginManager.IsPluginInUse(" test"), "Should return false when there is additional spacing") + assert.False(t, pluginManager.IsPluginInUse("tes"), "Should return false for partial name match") + assert.False(t, pluginManager.IsPluginInUse(""), "Should return false for empty string") + + // Re-initialize with empty plugin list + err = pluginManager.Init(pluginService, []driver_infrastructure.ConnectionPlugin{}) + require.Nil(t, err) + + // Should no longer find the plugin + assert.False(t, pluginManager.IsPluginInUse("*test.TestPlugin"), "Should return false after plugins are removed") +} diff --git a/.test/test/plugin_service_test.go b/.test/test/plugin_service_test.go index 600bfcc8..0ce4f800 100644 --- a/.test/test/plugin_service_test.go +++ b/.test/test/plugin_service_test.go @@ -18,12 +18,14 @@ package test import ( "database/sql/driver" + "testing" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" - "github.com/aws/aws-advanced-go-wrapper/pgx-driver" - "testing" + pgx_driver "github.com/aws/aws-advanced-go-wrapper/pgx-driver" "github.com/stretchr/testify/assert" ) @@ -348,3 +350,82 @@ func TestFillAliasesNonEmptyAliases(t *testing.T) { target.FillAliases(MockDriverConn{}, hostA) assert.Equal(t, 1, len(hostA.Aliases)) } + +func TestGetStatusEmptyCache(t *testing.T) { + target, _, _, err := beforePluginServiceTests() + assert.Nil(t, err) + + plugin_helpers.ClearCaches() + + status, found := target.GetBgStatus("test-id") + assert.False(t, found, "Should return false when status not found in cache") + assert.True(t, status.IsZero(), "Should return zero BlueGreenStatus when not found") +} + +func TestGetSetStatus(t *testing.T) { + target, _, _, err := beforePluginServiceTests() + assert.Nil(t, err) + + plugin_helpers.ClearCaches() + + target.SetBgStatus(driver_infrastructure.NewBgStatus( + "test-id", + driver_infrastructure.IN_PROGRESS, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ), "test-bg") + + // Try to retrieve with different formats + testCases := []string{ + "test-bg", + "TEST-BG", + "Test-Bg", + "TeSt-Bg", + " test-bg", + "test-bg ", + " test-bg ", + "\ttest-bg\t", + "\ntest-bg\n", + } + + for _, testId := range testCases { + retrievedStatus, found := target.GetBgStatus(testId) + assert.True(t, found, "Should find status regardless of case for ID: %s", testId) + assert.False(t, retrievedStatus.IsZero(), "Status should not be zero for ID: %s", testId) + assert.Equal(t, driver_infrastructure.IN_PROGRESS, retrievedStatus.GetCurrentPhase(), "Should retrieve correct phase for ID: %s", testId) + } +} + +func TestSetStatusUpdateExistingStatus(t *testing.T) { + target, _, _, err := beforePluginServiceTests() + assert.Nil(t, err) + + plugin_helpers.ClearCaches() + + testId := "deployment-update-test" + connectRoutings := []driver_infrastructure.ConnectRouting{} + executeRoutings := []driver_infrastructure.ExecuteRouting{} + roleByHost := utils.NewRWMap[driver_infrastructure.BlueGreenRole]() + correspondingHosts := utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]() + + initialStatus := driver_infrastructure.NewBgStatus(testId, driver_infrastructure.CREATED, connectRoutings, executeRoutings, roleByHost, correspondingHosts) + target.SetBgStatus(initialStatus, testId) + + retrievedStatus, found := target.GetBgStatus(testId) + assert.True(t, found, "Should find initial status") + assert.Equal(t, driver_infrastructure.CREATED, retrievedStatus.GetCurrentPhase(), "Should have initial phase") + + updatedStatus := driver_infrastructure.NewBgStatus(testId, driver_infrastructure.COMPLETED, connectRoutings, executeRoutings, roleByHost, correspondingHosts) + target.SetBgStatus(updatedStatus, testId) + + retrievedStatus, found = target.GetBgStatus(testId) + assert.True(t, found, "Should find updated status") + assert.Equal(t, driver_infrastructure.COMPLETED, retrievedStatus.GetCurrentPhase(), "Should have updated phase") + + target.SetBgStatus(driver_infrastructure.BlueGreenStatus{}, testId) + + _, found = target.GetBgStatus(testId) + assert.False(t, found, "Should remove status") +} diff --git a/.test/test/rds_host_list_provider_test.go b/.test/test/rds_host_list_provider_test.go index 5c9df3fc..fdb44d5c 100644 --- a/.test/test/rds_host_list_provider_test.go +++ b/.test/test/rds_host_list_provider_test.go @@ -18,11 +18,13 @@ package test import ( "database/sql/driver" - "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" - "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" "strings" "testing" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/stretchr/testify/assert" ) @@ -33,15 +35,13 @@ var mockPgAuroraDialect = &driver_infrastructure.AuroraPgDatabaseDialect{} func beforePgTests() *driver_infrastructure.RdsHostListProvider { driver_infrastructure.ClearAllRdsHostListProviderCaches() - var mockPgProps = map[string]string{"clusterId": "pg_cluster"} - var mockPgDsn = "postgres://someUser:somePassword@localhost:5432/pgx_test?sslmode=disable&foo=bar" - return driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, mockPgProps, mockPgDsn, nil, nil) + mockPgProps, _ := utils.ParseDsn("postgres://someUser:somePassword@localhost:5432/pgx_test?sslmode=disable&foo=bar&clusterId=pg_cluster") + return driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, mockPgProps, nil, nil) } func beforeMySqlTests() *driver_infrastructure.RdsHostListProvider { driver_infrastructure.ClearAllRdsHostListProviderCaches() - mockMySQLProps := map[string]string{"clusterId": "mysql_cluster"} - mockMySQLDsn := "someUser:somePassword@tcp(mydatabase.com:3306)/myDatabase?foo=bar&pop=snap" - return driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, &driver_infrastructure.AuroraMySQLDatabaseDialect{}, mockMySQLProps, mockMySQLDsn, nil, nil) + mockMySQLProps, _ := utils.ParseDsn("someUser:somePassword@tcp(mydatabase.com:3306)/myDatabase?foo=bar&pop=snap&clusterId=mysql_cluster") + return driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, &driver_infrastructure.AuroraMySQLDatabaseDialect{}, mockMySQLProps, nil, nil) } func TestGetClusterId(t *testing.T) { @@ -239,9 +239,8 @@ func TestMySQLIdentifyConnection(t *testing.T) { func TestSuggestedClusterIdForRds(t *testing.T) { driver_infrastructure.ClearAllRdsHostListProviderCaches() - - dsn := "postgresql://user:password@name.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database" - provider1 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, emptyProps, dsn, nil, nil) + props, _ := utils.ParseDsn("postgresql://user:password@name.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database") + provider1 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, props, nil, nil) mockConn := MockConn{} mockConn.updateQueryRowSingleUse([]string{"hostName", "isWriter", "cpu", "lag", "lastUpdateTime"}, []driver.Value{"instance-a-1", true, 1.0, 2.0, 0}) @@ -251,7 +250,7 @@ func TestSuggestedClusterIdForRds(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "instance-a-1.xyz.us-east-2.rds.amazonaws.com", hosts[0].Host) - provider2 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, emptyProps, dsn, nil, nil) + provider2 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, props, nil, nil) actualClusterId1, err1 := provider1.GetClusterId() actualClusterId2, err2 := provider2.GetClusterId() assert.Equal(t, actualClusterId1, actualClusterId2) @@ -268,9 +267,8 @@ func TestSuggestedClusterIdForRds(t *testing.T) { func TestNoSuggestedClusterId(t *testing.T) { driver_infrastructure.ClearAllRdsHostListProviderCaches() - - dsn1 := "postgresql://user:password@name1.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database" - provider1 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, emptyProps, dsn1, nil, nil) + props, _ := utils.ParseDsn("postgresql://user:password@name1.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database") + provider1 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, props, nil, nil) mockConn := MockConn{} mockConn.updateQueryRowSingleUse([]string{"hostName", "isWriter", "cpu", "lag", "lastUpdateTime"}, []driver.Value{"instance-a-1", true, 1.0, 2.0, 0}) @@ -280,8 +278,8 @@ func TestNoSuggestedClusterId(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "instance-a-1.xyz.us-east-2.rds.amazonaws.com", hosts[0].Host) - dsn2 := "postgresql://user:password@name2.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database" - provider2 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, emptyProps, dsn2, nil, nil) + props, _ = utils.ParseDsn("postgresql://user:password@name2.cluster-xyz.us-east-2.rds.amazonaws.com:5432/database") + provider2 := driver_infrastructure.NewRdsHostListProvider(mockHostListProviderService, mockPgAuroraDialect, props, nil, nil) mockConn = MockConn{} mockConn.updateQueryRowSingleUse([]string{"hostName", "isWriter", "cpu", "lag", "lastUpdateTime"}, []driver.Value{"instance-b-1", true, 1.0, 2.0, 0}) diff --git a/.test/test/rds_utils_test.go b/.test/test/rds_utils_test.go index 08d55a1b..37b752f7 100644 --- a/.test/test/rds_utils_test.go +++ b/.test/test/rds_utils_test.go @@ -50,9 +50,9 @@ const ( oldChinaRegionLimitlessDbShardGroup = "database-test-name.shardgrp-XYZ.cn-northwest-1.rds.amazonaws.com.cn" oldChinaRegionLimitlessDbShardGroupTrailingDot = "database-test-name.shardgrp-XYZ.cn-northwest-1.rds.amazonaws.com.cn." - extraRdsChinaPath = "database-test-name.cluster-XYZ.rds.cn-northwest-1.rds.amazonaws.com.cn" //nolint:unused - missingCnChinaPath = "database-test-name.cluster-XYZ.rds.cn-northwest-1.amazonaws.com" //nolint:unused - missingRegionChinaPath = "database-test-name.cluster-XYZ.rds.amazonaws.com.cn" //nolint:unused + extraRdsChinaPath = "database-test-name.cluster-XYZ.rds.cn-northwest-1.rds.amazonaws.com.cn" + missingCnChinaPath = "database-test-name.cluster-XYZ.rds.cn-northwest-1.amazonaws.com" + missingRegionChinaPath = "database-test-name.cluster-XYZ.rds.amazonaws.com.cn" usEastRegionElbUrl = "elb-name.elb.us-east-2.amazonaws.com" usEastRegionElbUrlTrailingDot = "elb-name.elb.us-east-2.amazonaws.com." @@ -72,6 +72,10 @@ const ( usIsoEastRegionProxy = "proxy-test-name.proxy-XYZ.rds.us-iso-east-1.c2s.ic.gov" usIsoEastRegionCustomDomain = "custom-test-name.cluster-custom-XYZ.rds.us-iso-east-1.c2s.ic.gov" usIsoEastRegionLimitlessDbShardGroup = "database-test-name.shardgrp-XYZ.rds.us-iso-east-1.c2s.ic.gov" + + blueInstance = "myapp-blue.abc123.us-east-1.rds.amazonaws.com" + greenInstance = "myapp-green-abc123.def456.us-east-1.rds.amazonaws.com" + oldInstance = "myapp-old1.ghi789.us-east-1.rds.amazonaws.com" ) func TestIsRdsDns(t *testing.T) { @@ -107,6 +111,9 @@ func TestIsRdsDns(t *testing.T) { assert.True(t, utils.IsRdsDns(chinaRegionProxy)) assert.True(t, utils.IsRdsDns(chinaRegionCustomDomain)) assert.True(t, utils.IsRdsDns(chinaRegionLimitlessDbShardGroup)) + assert.True(t, utils.IsRdsDns(extraRdsChinaPath)) + assert.True(t, utils.IsRdsDns(missingCnChinaPath)) + assert.False(t, utils.IsRdsDns(missingRegionChinaPath)) assert.True(t, utils.IsRdsDns(oldChinaRegionCluster)) assert.True(t, utils.IsRdsDns(oldChinaRegionClusterTrailingDot)) @@ -269,3 +276,250 @@ func TestGetRdsClusterHostUrl(t *testing.T) { assert.Equal(t, "database-test-name.cluster-XYZ.rds.us-isob-east-1.sc2s.sgov.gov", utils.GetRdsClusterHostUrl(usIsobEastRegionCluster)) assert.Equal(t, "database-test-name.cluster-XYZ.rds.us-iso-east-1.c2s.ic.gov", utils.GetRdsClusterHostUrl(usIsoEastRegionCluster)) } + +func TestIsRdsInstance(t *testing.T) { + assert.True(t, utils.IsRdsInstance(usEastRegionInstance), "Should identify RDS instance") + assert.True(t, utils.IsRdsInstance(chinaRegionInstance), "Should identify China RDS instance") + assert.True(t, utils.IsRdsInstance(oldChinaRegionInstance), "Should identify old China RDS instance") + + assert.False(t, utils.IsRdsInstance(usEastRegionCluster), "Should not identify cluster as instance") + assert.False(t, utils.IsRdsInstance(usEastRegionClusterReadOnly), "Should not identify read-only cluster as instance") + assert.False(t, utils.IsRdsInstance(chinaRegionCluster), "Should not identify China cluster as instance") + + assert.False(t, utils.IsRdsInstance(usEastRegionProxy), "Should not identify proxy as instance") + assert.False(t, utils.IsRdsInstance(chinaRegionProxy), "Should not identify China proxy as instance") + + assert.False(t, utils.IsRdsInstance("example.com"), "Should not identify non-RDS host as instance") + assert.False(t, utils.IsRdsInstance("database.example.org"), "Should not identify non-RDS host as instance") + assert.False(t, utils.IsRdsInstance(""), "Should not identify empty string as instance") + + assert.True(t, utils.IsRdsInstance("instance-test-name.XYZ.us-east-2.rds.amazonaws.com."), "Should handle trailing dot") + + assert.True(t, utils.IsRdsInstance(blueInstance), "Should identify blue instance as RDS instance") + assert.True(t, utils.IsRdsInstance(greenInstance), "Should identify green instance as RDS instance") + assert.True(t, utils.IsRdsInstance(oldInstance), "Should identify old instance as RDS instance") +} + +func TestIsGreenInstance(t *testing.T) { + greenInstances := []string{ + "myapp-green-abc123.def456.us-east-1.rds.amazonaws.com", + "database-green-xyz789.cluster-abc123.us-west-2.rds.amazonaws.com", + "test-green-123abc.instance.eu-west-1.rds.amazonaws.com", + "prod-GREEN-456DEF.cluster.ap-southeast-1.rds.amazonaws.com", // case insensitive + } + + for _, host := range greenInstances { + assert.True(t, utils.IsGreenInstance(host), "Should identify green instance: %s", host) + } + + nonGreenInstances := []string{ + blueInstance, + oldInstance, + "test-instance.xyz789.eu-west-1.rds.amazonaws.com", + "prod-cluster.cluster-abc123.ap-southeast-1.rds.amazonaws.com", + "example.com", + "", + "myapp-greenish.abc123.us-east-1.rds.amazonaws.com", + "myapp-green.abc123.us-east-1.rds.amazonaws.com", + } + for _, host := range nonGreenInstances { + assert.False(t, utils.IsGreenInstance(host), "Should not identify as green instance: %s", host) + } + + assert.False(t, utils.IsGreenInstance(""), "Should return false for empty string") + assert.False(t, utils.IsGreenInstance(" "), "Should return false for whitespace") +} + +func TestIsNotOldInstance(t *testing.T) { + nonOldInstances := []string{ + "myapp-old1ish.abc123.us-east-1.rds.amazonaws.com", + "prod-old1-cluster.cluster-abc123.ap-southeast-1.rds.amazonaws.com", + "myapp-blue.abc123.us-east-1.rds.amazonaws.com", + "database-green-xyz789.def456.us-west-2.rds.amazonaws.com", + "test-instance.ghi789.eu-west-1.rds.amazonaws.com", + "prod-cluster.cluster-abc123.ap-southeast-1.rds.amazonaws.com", + "example.com", + "", + " ", + } + + for _, host := range nonOldInstances { + assert.True(t, utils.IsNotOldInstance(host), "Should identify as not old instance: %s", host) + } + + oldInstances := []string{ + "myapp-old1.abc123.us-east-1.rds.amazonaws.com", + "database-old1.def456.us-west-2.rds.amazonaws.com", + "test-OLD1.ghi789.eu-west-1.rds.amazonaws.com", + } + + for _, host := range oldInstances { + assert.False(t, utils.IsNotOldInstance(host), "Should identify as old instance: %s", host) + } +} + +func TestIsNotGreenAndNotOldInstance(t *testing.T) { + blueInstances := []string{ + blueInstance, + "database-prod.def456.us-west-2.rds.amazonaws.com", + "test-instance.ghi789.eu-west-1.rds.amazonaws.com", + "prod-cluster.cluster-abc123.ap-southeast-1.rds.amazonaws.com", + } + + for _, host := range blueInstances { + assert.True(t, utils.IsNotGreenAndNotOldInstance(host), "Should identify not green and not old: %s", host) + } + + greenInstances := []string{ + greenInstance, + "database-green-xyz789.cluster-abc123.us-west-2.rds.amazonaws.com", + } + + for _, host := range greenInstances { + assert.False(t, utils.IsNotGreenAndNotOldInstance(host), "Should identify as green instance: %s", host) + } + + oldInstances := []string{ + oldInstance, + "database-old1.def456.us-west-2.rds.amazonaws.com", + } + + for _, host := range oldInstances { + assert.False(t, utils.IsNotGreenAndNotOldInstance(host), "Should identify as old instance: %s", host) + } + + assert.False(t, utils.IsNotGreenAndNotOldInstance(""), "Should return false for empty string") +} + +func TestRemoveGreenInstancePrefix(t *testing.T) { + testCases := []struct { + input string + expected string + desc string + }{ + { + input: "myapp-green-abc123.def456.us-east-1.rds.amazonaws.com", + expected: "myapp.def456.us-east-1.rds.amazonaws.com", + desc: "Should remove green prefix from standard green instance", + }, + { + input: "database-green-xyz789.cluster-abc123.us-west-2.rds.amazonaws.com", + expected: "database.cluster-abc123.us-west-2.rds.amazonaws.com", + desc: "Should remove green prefix from green cluster", + }, + { + input: "test-GREEN-123ABC.instance.eu-west-1.rds.amazonaws.com", + expected: "test.instance.eu-west-1.rds.amazonaws.com", + desc: "Should handle case insensitive green prefix", + }, + } + + for _, tc := range testCases { + result := utils.RemoveGreenInstancePrefix(tc.input) + assert.Equal(t, tc.expected, result, tc.desc) + } + + nonGreenInstances := []string{ + "myapp-blue.abc123.us-east-1.rds.amazonaws.com", + "database-old1.def456.us-west-2.rds.amazonaws.com", + "test-instance.ghi789.eu-west-1.rds.amazonaws.com", + "example.com", + "", + } + + for _, host := range nonGreenInstances { + result := utils.RemoveGreenInstancePrefix(host) + assert.Equal(t, host, result, "Should return unchanged for non-green instance: %s", host) + } + + assert.Equal(t, "", utils.RemoveGreenInstancePrefix(""), "Should handle empty string") + + // Test hostid pattern fallback + hostIdPattern := "myapp-green-abc123" + expectedHostId := "myapp" + result := utils.RemoveGreenInstancePrefix(hostIdPattern) + assert.Equal(t, expectedHostId, result) +} + +func TestGetRdsClusterId(t *testing.T) { + testCases := []struct { + input string + expected string + desc string + }{ + { + input: usEastRegionCluster, + expected: "database-test-name", + desc: "Should extract cluster ID from US East cluster", + }, + { + input: usEastRegionClusterReadOnly, + expected: "database-test-name", + desc: "Should extract cluster ID from read-only cluster", + }, + { + input: chinaRegionCluster, + expected: "database-test-name", + desc: "Should extract cluster ID from China cluster", + }, + { + input: oldChinaRegionCluster, + expected: "database-test-name", + desc: "Should extract cluster ID from old China cluster", + }, + { + input: usEastRegionClusterTrailingDot, + expected: "database-test-name", + desc: "Should handle trailing dot", + }, + { + input: usEastRegionCustomDomain, + expected: "custom-test-name", + desc: "Should extract cluster ID from custom domain", + }, + { + input: usEastRegionInstance, + expected: "instance-test-name", + desc: "Should extract instance ID from US East instance", + }, + { + input: chinaRegionInstance, + expected: "instance-test-name", + desc: "Should extract instance ID from China instance", + }, + { + input: "myapp-blue.cluster-abc123.us-east-1.rds.amazonaws.com", + expected: "myapp-blue", + desc: "Should extract ID from blue cluster", + }, + { + input: "myapp-green-def456.cluster-xyz789.us-west-2.rds.amazonaws.com", + expected: "myapp-green-def456", + desc: "Should extract ID from green cluster", + }, + { + input: "myapp-old1.cluster-ghi789.eu-west-1.rds.amazonaws.com", + expected: "myapp-old1", + desc: "Should extract ID from old cluster", + }, + } + + for _, tc := range testCases { + result := utils.GetRdsClusterId(tc.input) + assert.Equal(t, tc.expected, result, tc.desc) + } + + nonRdsHosts := []string{ + "example.com", + "database.example.org", + "localhost", + "192.168.1.1", + " ", + "", + } + + for _, host := range nonRdsHosts { + result := utils.GetRdsClusterId(host) + assert.Equal(t, "", result, "Should return empty string for non-RDS host: %s", host) + } +} diff --git a/.test/test/read_write_splitting_plugin_test.go b/.test/test/read_write_splitting_plugin_test.go index c0a969cb..84e7f93a 100644 --- a/.test/test/read_write_splitting_plugin_test.go +++ b/.test/test/read_write_splitting_plugin_test.go @@ -93,7 +93,7 @@ func TestReadWriteSplittingPlugin_InitHostProvider(t *testing.T) { plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) called := false - err := plugin.InitHostProvider("url", nil, mockHostProvider, func() error { + err := plugin.InitHostProvider(nil, mockHostProvider, func() error { called = true return nil }) @@ -136,7 +136,7 @@ func TestReadWriteSplittingPlugin_Connect_StaticProvider(t *testing.T) { mockHostProvider.EXPECT().IsStaticHostListProvider().Return(true) plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) - _ = plugin.InitHostProvider("url", nil, mockHostProvider, func() error { return nil }) + _ = plugin.InitHostProvider(nil, mockHostProvider, func() error { return nil }) mockConn := mock_database_sql_driver.NewMockConn(ctrl) resultConn := mockConn @@ -160,7 +160,7 @@ func TestReadWriteSplittingPlugin_Connect_UnknownHostRole(t *testing.T) { mockPluginService.EXPECT().GetHostRole(gomock.Any()).Return(host_info_util.UNKNOWN) plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) - _ = plugin.InitHostProvider("url", nil, mockHostProvider, func() error { return nil }) + _ = plugin.InitHostProvider(nil, mockHostProvider, func() error { return nil }) mockConn := mock_database_sql_driver.NewMockConn(ctrl) resultConn := mockConn @@ -188,7 +188,7 @@ func TestReadWriteSplittingPlugin_Connect_NilHostRole(t *testing.T) { mockPluginService.EXPECT().GetInitialConnectionHostInfo().Return(nil) plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) - _ = plugin.InitHostProvider("url", nil, mockHostProvider, func() error { return nil }) + _ = plugin.InitHostProvider(nil, mockHostProvider, func() error { return nil }) mockConn := mock_database_sql_driver.NewMockConn(ctrl) resultConn := mockConn @@ -214,7 +214,7 @@ func TestReadWriteSplittingPlugin_Connect_SameHostRole(t *testing.T) { &host_info_util.HostInfo{Role: host_info_util.READER}) plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) - _ = plugin.InitHostProvider("url", nil, mockHostProvider, func() error { return nil }) + _ = plugin.InitHostProvider(nil, mockHostProvider, func() error { return nil }) mockConn := mock_database_sql_driver.NewMockConn(ctrl) resultConn := mockConn @@ -241,7 +241,7 @@ func TestReadWriteSplittingPlugin_Connect_DifferentHostRole(t *testing.T) { mockHostProvider.EXPECT().SetInitialConnectionHostInfo(gomock.Any()).Return() plugin := read_write_splitting.NewReadWriteSplittingPlugin(mockPluginService, nil) - _ = plugin.InitHostProvider("url", nil, mockHostProvider, func() error { return nil }) + _ = plugin.InitHostProvider(nil, mockHostProvider, func() error { return nil }) mockConn := mock_database_sql_driver.NewMockConn(ctrl) resultConn := mockConn diff --git a/.test/test/rw_map_test.go b/.test/test/rw_map_test.go new file mode 100644 index 00000000..740aee05 --- /dev/null +++ b/.test/test/rw_map_test.go @@ -0,0 +1,324 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package test + +import ( + "fmt" + "sync" + "testing" + + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + + "github.com/stretchr/testify/assert" +) + +func TestNewRWMap(t *testing.T) { + rwMap := utils.NewRWMap[int]() + assert.NotNil(t, rwMap, "NewRWMap should return a non-nil map") + assert.Equal(t, 0, rwMap.Size(), "New map should be empty") +} + +func TestRWMapPutAndGet(t *testing.T) { + rwMap := utils.NewRWMap[string]() + + rwMap.Put("key1", "value1") + value, ok := rwMap.Get("key1") + assert.True(t, ok, "Should find the key that was put") + assert.Equal(t, "value1", value, "Should return the correct value") + + value, ok = rwMap.Get("nonexistent") + assert.False(t, ok, "Should not find non-existent key") + assert.Equal(t, "", value, "Should return zero value for non-existent key") + + rwMap.Put("key1", "newvalue1") + value, ok = rwMap.Get("key1") + assert.True(t, ok, "Should find the key after overwrite") + assert.Equal(t, "newvalue1", value, "Should return the new value") + + // Test empty key + rwMap.Put("", "empty-key-value") + value, ok = rwMap.Get("") + assert.True(t, ok, "Should handle empty key") + assert.Equal(t, "empty-key-value", value, "Should store and retrieve empty key value") + + // Test nil value (for pointer types) + ptrMap := utils.NewRWMap[*string]() + ptrMap.Put("nil-value", nil) + ptrValue, ok := ptrMap.Get("nil-value") + assert.True(t, ok, "Should handle nil pointer value") + assert.Nil(t, ptrValue, "Should store and retrieve nil pointer") + + // Test zero value + intMap := utils.NewRWMap[int]() + intMap.Put("zero", 0) + intValue, ok := intMap.Get("zero") + assert.True(t, ok, "Should handle zero value") + assert.Equal(t, 0, intValue, "Should store and retrieve zero value") + + // Test very long key + longKey := string(make([]byte, 10000)) + for i := range longKey { + longKey = longKey[:i] + "a" + longKey[i+1:] + } + rwMap.Put(longKey, "long-key-value") + value, ok = rwMap.Get(longKey) + assert.True(t, ok, "Should handle very long key") + assert.Equal(t, "long-key-value", value, "Should store and retrieve long key value") +} + +func TestRWMapComputeIfAbsent(t *testing.T) { + rwMap := utils.NewRWMap[int]() + + computeCallCount := 0 + computeFunc := func() int { + computeCallCount++ + return 100 + } + + value := rwMap.ComputeIfAbsent("key1", computeFunc) + assert.Equal(t, 100, value, "Should return computed value") + assert.Equal(t, 1, computeCallCount, "Compute function should be called once") + + storedValue, ok := rwMap.Get("key1") + assert.True(t, ok, "Computed value should be stored") + assert.Equal(t, 100, storedValue, "Stored value should match computed value") + + value = rwMap.ComputeIfAbsent("key1", computeFunc) + assert.Equal(t, 100, value, "Should return existing value") + assert.Equal(t, 1, computeCallCount, "Compute function should not be called again") + + value = rwMap.ComputeIfAbsent("key2", func() int { return 200 }) + assert.Equal(t, 200, value, "Should return new computed value") + assert.Equal(t, 2, rwMap.Size(), "Map should contain both keys") +} + +func TestRWMapPutIfAbsent(t *testing.T) { + rwMap := utils.NewRWMap[string]() + + rwMap.PutIfAbsent("key1", "value1") + value, ok := rwMap.Get("key1") + assert.True(t, ok, "Should find the key that was put") + assert.Equal(t, "value1", value, "Should return the correct value") + + rwMap.PutIfAbsent("key1", "newvalue1") + value, ok = rwMap.Get("key1") + assert.True(t, ok, "Should still find the key") + assert.Equal(t, "value1", value, "Should return the original value, not the new one") + + rwMap.PutIfAbsent("key2", "value2") + value, ok = rwMap.Get("key2") + assert.True(t, ok, "Should find the second key") + assert.Equal(t, "value2", value, "Should return the correct value for second key") + + assert.Equal(t, 2, rwMap.Size(), "Map should contain both keys") +} + +func TestRWMapRemove(t *testing.T) { + rwMap := utils.NewRWMap[string]() + + rwMap.Put("key1", "value1") + rwMap.Put("key2", "value2") + rwMap.Put("key3", "value3") + assert.Equal(t, 3, rwMap.Size(), "Map should contain 3 items") + + rwMap.Remove("key2") + assert.Equal(t, 2, rwMap.Size(), "Map should contain 2 items after removal") + + _, ok := rwMap.Get("key2") + assert.False(t, ok, "Removed key should not be found") + + value, ok := rwMap.Get("key1") + assert.True(t, ok, "Other keys should still exist") + assert.Equal(t, "value1", value, "Other values should be unchanged") + + value, ok = rwMap.Get("key3") + assert.True(t, ok, "Other keys should still exist") + assert.Equal(t, "value3", value, "Other values should be unchanged") + + rwMap.Remove("nonexistent") + assert.Equal(t, 2, rwMap.Size(), "Size should remain unchanged when removing non-existent key") +} + +func TestRWMapClear(t *testing.T) { + rwMap := utils.NewRWMap[int]() + + rwMap.Put("key1", 1) + rwMap.Put("key2", 2) + rwMap.Put("key3", 3) + assert.Equal(t, 3, rwMap.Size(), "Map should contain 3 items") + + rwMap.Clear() + assert.Equal(t, 0, rwMap.Size(), "Map should be empty after clear") + + _, ok := rwMap.Get("key1") + assert.False(t, ok, "All keys should be removed after clear") + _, ok = rwMap.Get("key2") + assert.False(t, ok, "All keys should be removed after clear") + _, ok = rwMap.Get("key3") + assert.False(t, ok, "All keys should be removed after clear") + + rwMap.Clear() + assert.Equal(t, 0, rwMap.Size(), "Clearing empty map should not cause issues") +} + +func TestRWMapClearWithDisposalFunc(t *testing.T) { + disposedValues := make([]int, 0) + disposalFunc := func(value int) bool { + disposedValues = append(disposedValues, value) + return true + } + rwMap := utils.NewRWMapWithDisposalFunc[int](disposalFunc) + + rwMap.Put("key1", 1) + rwMap.Put("key2", 2) + rwMap.Put("key3", 3) + + rwMap.Clear() + + assert.Equal(t, 0, rwMap.Size(), "Map should be empty after clear") + + assert.Len(t, disposedValues, 3, "Disposal function should be called for all values") + assert.Contains(t, disposedValues, 1, "Should dispose value 1") + assert.Contains(t, disposedValues, 2, "Should dispose value 2") + assert.Contains(t, disposedValues, 3, "Should dispose value 3") +} + +func TestRWMapGetAllEntries(t *testing.T) { + rwMap := utils.NewRWMap[string]() + + entries := rwMap.GetAllEntries() + assert.Empty(t, entries, "Empty map should return empty entries") + + rwMap.Put("key1", "value1") + rwMap.Put("key2", "value2") + rwMap.Put("key3", "value3") + + entries = rwMap.GetAllEntries() + assert.Len(t, entries, 3, "Should return all entries") + assert.Equal(t, "value1", entries["key1"], "Should contain correct value for key1") + assert.Equal(t, "value2", entries["key2"], "Should contain correct value for key2") + assert.Equal(t, "value3", entries["key3"], "Should contain correct value for key3") + + entries["key4"] = "value4" + _, ok := rwMap.Get("key4") + assert.False(t, ok, "Modifying returned entries should not affect original map") + assert.Equal(t, 3, rwMap.Size(), "Original map size should be unchanged") +} + +func TestRWMapReplaceCacheWithCopy(t *testing.T) { + sourceMap := utils.NewRWMap[int]() + sourceMap.Put("key1", 1) + sourceMap.Put("key2", 2) + sourceMap.Put("key3", 3) + + targetMap := utils.NewRWMap[int]() + targetMap.Put("oldkey1", 10) + targetMap.Put("oldkey2", 20) + + targetMap.ReplaceCacheWithCopy(sourceMap) + + assert.Equal(t, 3, targetMap.Size(), "Target map should have source map size") + + value, ok := targetMap.Get("key1") + assert.True(t, ok, "Target should contain source keys") + assert.Equal(t, 1, value, "Target should contain source values") + + value, ok = targetMap.Get("key2") + assert.True(t, ok, "Target should contain source keys") + assert.Equal(t, 2, value, "Target should contain source values") + + value, ok = targetMap.Get("key3") + assert.True(t, ok, "Target should contain source keys") + assert.Equal(t, 3, value, "Target should contain source values") + + _, ok = targetMap.Get("oldkey1") + assert.False(t, ok, "Old keys should be removed") + _, ok = targetMap.Get("oldkey2") + assert.False(t, ok, "Old keys should be removed") + + assert.Equal(t, 3, sourceMap.Size(), "Source map should be unchanged") +} + +func TestRWMapSize(t *testing.T) { + rwMap := utils.NewRWMap[string]() + + assert.Equal(t, 0, rwMap.Size(), "Empty map should have size 0") + + rwMap.Put("key1", "value1") + assert.Equal(t, 1, rwMap.Size(), "Map should have size 1 after adding one item") + + rwMap.Put("key2", "value2") + assert.Equal(t, 2, rwMap.Size(), "Map should have size 2 after adding two items") + + rwMap.Put("key1", "newvalue1") + assert.Equal(t, 2, rwMap.Size(), "Map size should not change when overwriting") + + rwMap.Remove("key1") + assert.Equal(t, 1, rwMap.Size(), "Map should have size 1 after removing one item") + + rwMap.Remove("key2") + assert.Equal(t, 0, rwMap.Size(), "Map should have size 0 after removing all items") + + rwMap.Remove("nonexistent") + assert.Equal(t, 0, rwMap.Size(), "Map size should not change when removing non-existent key") +} + +func TestRWMapConcurrency(t *testing.T) { + rwMap := utils.NewRWMap[int]() + numGoroutines := 100 + numOperationsPerGoroutine := 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines * 3) + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperationsPerGoroutine; j++ { + key := fmt.Sprintf("writer-%d-%d", id, j) + rwMap.Put(key, id*1000+j) + } + }(i) + } + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperationsPerGoroutine; j++ { + key := fmt.Sprintf("writer-%d-%d", id%10, j%10) + rwMap.Get(key) + } + }(i) + } + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperationsPerGoroutine; j++ { + key := fmt.Sprintf("compute-%d-%d", id, j) + rwMap.ComputeIfAbsent(key, func() int { return id*2000 + j }) + } + }(i) + } + + wg.Wait() + + size := rwMap.Size() + assert.True(t, size > 0, "Map should contain items after concurrent operations") + + entries := rwMap.GetAllEntries() + assert.Equal(t, size, len(entries), "GetAllEntries should return consistent number of items") +} diff --git a/.test/test/utils_test.go b/.test/test/utils_test.go index e147f027..c1a9c9db 100644 --- a/.test/test/utils_test.go +++ b/.test/test/utils_test.go @@ -17,26 +17,44 @@ package test import ( + "database/sql/driver" + "io" "testing" + mock_database_sql_driver "github.com/aws/aws-advanced-go-wrapper/.test/test/mocks/database_sql_driver" "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) func TestRollbackWithCurrentTx(t *testing.T) { - mockConn := MockConn{} - mockTx := NewMockTx() + ctrl := gomock.NewController(t) + defer ctrl.Finish() - utils.Rollback(&mockConn, mockTx) - assert.Equal(t, 1, *mockTx.rollbackCounter) - assert.Equal(t, 0, mockConn.execContextCounter) + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockTx := mock_database_sql_driver.NewMockTx(ctrl) + + mockTx.EXPECT().Rollback().Return(nil).Times(1) + + utils.Rollback(mockConn, mockTx) } func TestRollbackWithNilCurrentTx(t *testing.T) { - mockConn := MockConn{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockExecer := mock_database_sql_driver.NewMockExecerContext(ctrl) - utils.Rollback(&mockConn, nil) - assert.Equal(t, 1, mockConn.execContextCounter) + mockExecer.EXPECT().ExecContext(gomock.Any(), "rollback", gomock.Any()).Return(nil, nil).Times(1) + + // Create a combined mock that implements both Conn and ExecerContext + combinedMock := struct { + *mock_database_sql_driver.MockConn + *mock_database_sql_driver.MockExecerContext + }{mockConn, mockExecer} + + utils.Rollback(combinedMock, nil) } func TestCombineMaps(t *testing.T) { @@ -77,3 +95,166 @@ func TestCombineMaps(t *testing.T) { }, utils.CombineMaps(map5, map6)) } + +func TestGetFirstRowFromQuery(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Test with mock connection that returns data + mockConn := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows := mock_database_sql_driver.NewMockRows(ctrl) + + // Set up expectations for successful query + mockQueryer.EXPECT().QueryContext(gomock.Any(), "SELECT * FROM test", gomock.Any()).Return(mockRows, nil).Times(1) + mockRows.EXPECT().Columns().Return([]string{"col1", "col2", "col3"}).Times(1) + mockRows.EXPECT().Next(gomock.Any()).DoAndReturn(func(dest []driver.Value) error { + dest[0] = "value1" + dest[1] = "value2" + dest[2] = "value3" + return nil + }).Times(1) + mockRows.EXPECT().Close().Return(nil).Times(1) + + // Create a combined mock that implements both Conn and QueryerContext + combinedMock := struct { + *mock_database_sql_driver.MockConn + *mock_database_sql_driver.MockQueryerContext + }{mockConn, mockQueryer} + + result := utils.GetFirstRowFromQuery(combinedMock, "SELECT * FROM test") + assert.NotNil(t, result) + assert.Equal(t, 3, len(result)) + assert.Equal(t, "value1", result[0]) + assert.Equal(t, "value2", result[1]) + assert.Equal(t, "value3", result[2]) + + // Test with mock connection that returns no data + mockConn2 := mock_database_sql_driver.NewMockConn(ctrl) + mockQueryer2 := mock_database_sql_driver.NewMockQueryerContext(ctrl) + mockRows2 := mock_database_sql_driver.NewMockRows(ctrl) + + mockQueryer2.EXPECT().QueryContext(gomock.Any(), "SELECT * FROM empty", gomock.Any()).Return(mockRows2, nil).Times(1) + mockRows2.EXPECT().Columns().Return([]string{}).Times(1) + mockRows2.EXPECT().Next(gomock.Any()).Return(io.EOF).Times(1) // Still called even with empty columns + mockRows2.EXPECT().Close().Return(nil).Times(1) + + combinedMock2 := struct { + *mock_database_sql_driver.MockConn + *mock_database_sql_driver.MockQueryerContext + }{mockConn2, mockQueryer2} + + result = utils.GetFirstRowFromQuery(combinedMock2, "SELECT * FROM empty") + assert.Nil(t, result) + + // Test with connection that doesn't implement QueryerContext (just MockConn) + mockConn3 := mock_database_sql_driver.NewMockConn(ctrl) + result = utils.GetFirstRowFromQuery(mockConn3, "SELECT * FROM test") + assert.Nil(t, result) +} + +func TestFilterSliceFindFirst(t *testing.T) { + numbers := []int{1, 3, 4, 5, 6, 7} + result := utils.FilterSliceFindFirst(numbers, func(n int) bool { + return n%2 == 0 + }) + assert.Equal(t, 4, result) + + words := []string{"apple", "banana", "cherry", "blueberry"} + resultStr := utils.FilterSliceFindFirst(words, func(s string) bool { + return len(s) > 0 && s[0] == 'b' + }) + assert.Equal(t, "banana", resultStr) + + oddNumbers := []int{1, 3, 5, 7} + resultZero := utils.FilterSliceFindFirst(oddNumbers, func(n int) bool { + return n%2 == 0 + }) + assert.Equal(t, 0, resultZero) + + emptySlice := []int{} + resultEmpty := utils.FilterSliceFindFirst(emptySlice, func(n int) bool { + return n > 0 + }) + assert.Equal(t, 0, resultEmpty) +} + +func TestFilterSetFindFirst(t *testing.T) { + testMap := map[string]int{ + "apple": 1, + "banana": 2, + "cherry": 3, + "tomato": 4, + } + result := utils.FilterSetFindFirst(testMap, func(key string) bool { + return len(key) > 0 && key[0] == 't' + }) + assert.Equal(t, "tomato", result) + + intMap := map[int]string{ + 1: "one", + 2: "two", + 3: "three", + 4: "four", + } + resultInt := utils.FilterSetFindFirst(intMap, func(key int) bool { + return key%2 == 0 + }) + assert.True(t, resultInt == 2 || resultInt == 4) + + resultZero := utils.FilterSetFindFirst(testMap, func(key string) bool { + return len(key) > 0 && key[0] == 'z' + }) + assert.Equal(t, "", resultZero) + + emptyMap := map[string]int{} + resultEmpty := utils.FilterSetFindFirst(emptyMap, func(key string) bool { + return len(key) > 0 + }) + assert.Equal(t, "", resultEmpty) +} + +func TestFilterMapFindFirstValue(t *testing.T) { + testMap := map[string]string{ + "fruit1": "apple", + "fruit2": "banana", + "fruit3": "cherry", + "fruit4": "blueberry", + } + result := utils.FilterMapFindFirstValue(testMap, func(value string) bool { + return len(value) > 0 && value[0] == 'b' + }) + assert.True(t, result == "banana" || result == "blueberry") + + intMap := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + } + resultInt := utils.FilterMapFindFirstValue(intMap, func(value int) bool { + return value%2 == 0 + }) + assert.True(t, resultInt == 2 || resultInt == 4) + + resultZero := utils.FilterMapFindFirstValue(testMap, func(value string) bool { + return len(value) > 0 && value[0] == 'z' + }) + assert.Equal(t, "", resultZero) + + emptyMap := map[string]string{} + resultEmpty := utils.FilterMapFindFirstValue(emptyMap, func(value string) bool { + return len(value) > 0 + }) + assert.Equal(t, "", resultEmpty) +} + +func TestPair(t *testing.T) { + pair := utils.NewPair("hello", 42) + assert.Equal(t, "hello", pair.GetLeft(), "Should store left value correctly") + assert.Equal(t, 42, pair.GetRight(), "Should store right value correctly") + + intPair := utils.NewPair(10, 20) + assert.Equal(t, 10, intPair.GetLeft(), "Should store left int correctly") + assert.Equal(t, 20, intPair.GetRight(), "Should store right int correctly") +} diff --git a/aws-secrets-manager/aws_secrets_manager_connection_plugin.go b/aws-secrets-manager/aws_secrets_manager_connection_plugin.go index aa483662..b8113554 100644 --- a/aws-secrets-manager/aws_secrets_manager_connection_plugin.go +++ b/aws-secrets-manager/aws_secrets_manager_connection_plugin.go @@ -35,7 +35,8 @@ import ( ) func init() { - awssql.UsePluginFactory("awsSecretsManager", NewAwsSecretsManagerPluginFactory()) + awssql.UsePluginFactory(driver_infrastructure.SECRETS_MANAGER_PLUGIN_CODE, + NewAwsSecretsManagerPluginFactory()) } var fetchCredentialsCounterName = "secretsManager.fetchCredentials.count" @@ -120,6 +121,10 @@ func NewAwsSecretsManagerPlugin(pluginService driver_infrastructure.PluginServic }, err } +func (awsSecretsManagerPlugin *AwsSecretsManagerPlugin) GetPluginCode() string { + return driver_infrastructure.SECRETS_MANAGER_PLUGIN_CODE +} + func (awsSecretsManagerPlugin *AwsSecretsManagerPlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.CONNECT_METHOD, plugin_helpers.FORCE_CONNECT_METHOD} } diff --git a/awssql/driver/connection_plugin_chain_builder.go b/awssql/driver/connection_plugin_chain_builder.go index 3cc3573a..32a616e2 100644 --- a/awssql/driver/connection_plugin_chain_builder.go +++ b/awssql/driver/connection_plugin_chain_builder.go @@ -36,15 +36,16 @@ type PluginFactoryWeight struct { } var pluginWeightByCode = map[string]int{ - "readWriteSplitting": 600, - "failover": 700, - "efm": 800, - "limitless": 950, - "iam": 1000, - "awsSecretsManager": 1100, - "federatedAuth": 1200, - "okta": 1300, - "executionTime": WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, + driver_infrastructure.BLUE_GREEN_PLUGIN_CODE: 550, + driver_infrastructure.READ_WRITE_SPLITTING_PLUGIN_CODE: 600, + driver_infrastructure.FAILOVER_PLUGIN_CODE: 700, + driver_infrastructure.EFM_PLUGIN_CODE: 800, + driver_infrastructure.LIMITLESS_PLUGIN_CODE: 950, + driver_infrastructure.IAM_PLUGIN_CODE: 1000, + driver_infrastructure.SECRETS_MANAGER_PLUGIN_CODE: 1100, + driver_infrastructure.ADFS_PLUGIN_CODE: 1200, + driver_infrastructure.OKTA_PLUGIN_CODE: 1300, + driver_infrastructure.EXECUTION_TIME_PLUGIN_CODE: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, } type ConnectionPluginChainBuilder struct { @@ -105,11 +106,11 @@ func (builder *ConnectionPluginChainBuilder) GetPlugins( } } - defaultPlugin := driver_infrastructure.ConnectionPlugin(&plugins.DefaultPlugin{ + defaultPlugin := &plugins.DefaultPlugin{ PluginService: pluginService, DefaultConnProvider: pluginManager.GetDefaultConnectionProvider(), ConnProviderManager: pluginManager.GetConnectionProviderManager(), - }) + } resultPlugins = append(resultPlugins, defaultPlugin) if pluginsSorted { slog.Info(fmt.Sprintf("Plugins order has been rearranged. The following order is in effect: '%v'.", getFactoryOrder(resultPlugins))) diff --git a/awssql/driver/driver.go b/awssql/driver/driver.go index 957a3c31..40f219ef 100644 --- a/awssql/driver/driver.go +++ b/awssql/driver/driver.go @@ -27,6 +27,7 @@ import ( "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/bg" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/efm" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/limitless" "github.com/aws/aws-advanced-go-wrapper/awssql/plugins/read_write_splitting" @@ -35,11 +36,12 @@ import ( ) var pluginFactoryByCode = map[string]driver_infrastructure.ConnectionPluginFactory{ - "failover": plugins.NewFailoverPluginFactory(), - "efm": efm.NewHostMonitoringPluginFactory(), - "limitless": limitless.NewLimitlessPluginFactory(), - "executionTime": plugins.NewExecutionTimePluginFactory(), - "readWriteSplitting": read_write_splitting.NewReadWriteSplittingPluginFactory(), + driver_infrastructure.FAILOVER_PLUGIN_CODE: plugins.NewFailoverPluginFactory(), + driver_infrastructure.EFM_PLUGIN_CODE: efm.NewHostMonitoringPluginFactory(), + driver_infrastructure.LIMITLESS_PLUGIN_CODE: limitless.NewLimitlessPluginFactory(), + driver_infrastructure.EXECUTION_TIME_PLUGIN_CODE: plugins.NewExecutionTimePluginFactory(), + driver_infrastructure.READ_WRITE_SPLITTING_PLUGIN_CODE: read_write_splitting.NewReadWriteSplittingPluginFactory(), + driver_infrastructure.BLUE_GREEN_PLUGIN_CODE: bg.NewBlueGreenPluginFactory(), } var underlyingDriverList = map[string]driver.Driver{} @@ -86,7 +88,7 @@ func (d *AwsWrapperDriver) Open(dsn string) (driver.Conn, error) { } hostListProviderService := driver_infrastructure.HostListProviderService(pluginServiceImpl) - provider := hostListProviderService.CreateHostListProvider(props, dsn) + provider := hostListProviderService.CreateHostListProvider(props) hostListProviderService.SetHostListProvider(provider) telemetryCtx, ctx := pluginManager.GetTelemetryFactory().OpenTelemetryContext(telemetry.TELEMETRY_OPEN_CONNECTION, telemetry.TOP_LEVEL, nil) @@ -96,7 +98,7 @@ func (d *AwsWrapperDriver) Open(dsn string) (driver.Conn, error) { pluginManager.SetTelemetryContext(context.TODO()) }() - err = pluginManager.InitHostProvider(dsn, props, hostListProviderService) + err = pluginManager.InitHostProvider(props, hostListProviderService) if err != nil { return nil, err } @@ -157,7 +159,7 @@ func GetUnderlyingDriver(name string) driver.Driver { return underlyingDriverList[name] } -// This cleans up all long standing caches. To be called at the end of program, not each time a Conn is closed. +// This cleans up all long-standing caches. To be called at the end of program, not each time a Conn is closed. func ClearCaches() { driver_infrastructure.ClearCaches() plugin_helpers.ClearCaches() @@ -332,6 +334,10 @@ func (c *AwsWrapperConn) setReadWriteMode(ctx context.Context) error { return err } +func (c *AwsWrapperConn) UnwrapPlugin(pluginCode string) driver_infrastructure.ConnectionPlugin { + return c.pluginManager.UnwrapPlugin(pluginCode) +} + type AwsWrapperStmt struct { underlyingConn driver.Conn underlyingStmt driver.Stmt diff --git a/awssql/driver_infrastructure/bg_helpers.go b/awssql/driver_infrastructure/bg_helpers.go new file mode 100644 index 00000000..f6090a09 --- /dev/null +++ b/awssql/driver_infrastructure/bg_helpers.go @@ -0,0 +1,291 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package driver_infrastructure + +import ( + "database/sql/driver" + "fmt" + "log/slog" + "strings" + + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +type BlueGreenIntervalRate int + +const ( + INVALID BlueGreenIntervalRate = iota - 1 + BASELINE + INCREASED + HIGH +) + +const ( + BLUE_GREEN_SOURCE string = "BLUE_GREEN_DEPLOYMENT_SOURCE" + BLUE_GREEN_TARGET string = "BLUE_GREEN_DEPLOYMENT_TARGET" +) + +const ( + AVAILABLE string = "AVAILABLE" + SWITCHOVER_INITIATED string = "SWITCHOVER_INITIATED" + SWITCHOVER_IN_PROGRESS string = "SWITCHOVER_IN_PROGRESS" + SWITCHOVER_IN_POST_PROCESSING string = "SWITCHOVER_IN_POST_PROCESSING" + SWITCHOVER_COMPLETED string = "SWITCHOVER_COMPLETED" +) + +var ( + NOT_CREATED = BlueGreenPhase{"NOT_CREATED", 0, false} + CREATED = BlueGreenPhase{"CREATED", 1, false} + PREPARATION = BlueGreenPhase{"PREPARATION", 2, true} + IN_PROGRESS = BlueGreenPhase{"IN_PROGRESS", 3, true} + POST = BlueGreenPhase{"POST", 4, true} + COMPLETED = BlueGreenPhase{"COMPLETED", 5, true} +) + +var ( + SOURCE = BlueGreenRole{"SOURCE", 0} + TARGET = BlueGreenRole{"TARGET", 1} +) + +var blueGreenRoleMapping = map[string]BlueGreenRole{ + BLUE_GREEN_SOURCE: SOURCE, + BLUE_GREEN_TARGET: TARGET, +} + +var blueGreenStatusMapping = map[string]BlueGreenPhase{ + AVAILABLE: CREATED, + SWITCHOVER_INITIATED: PREPARATION, + SWITCHOVER_IN_PROGRESS: IN_PROGRESS, + SWITCHOVER_IN_POST_PROCESSING: POST, + SWITCHOVER_COMPLETED: COMPLETED, +} + +type BlueGreenPhase struct { + name string + phase int + isActiveSwitchoverOrCompleted bool +} + +func (b BlueGreenPhase) GetName() string { + return b.name +} + +func (b BlueGreenPhase) GetPhase() int { + return b.phase +} + +func (b BlueGreenPhase) IsActiveSwitchoverOrCompleted() bool { + return b.isActiveSwitchoverOrCompleted +} + +func (b BlueGreenPhase) IsZero() bool { + return b.name == "" +} + +func (b BlueGreenPhase) Equals(other BlueGreenPhase) bool { + return b.name == other.name && b.phase == other.phase && b.isActiveSwitchoverOrCompleted == other.isActiveSwitchoverOrCompleted +} + +func ParsePhase(statusKey string) BlueGreenPhase { + if statusKey == "" { + return NOT_CREATED + } + + phase, ok := blueGreenStatusMapping[strings.ToUpper(statusKey)] + if !ok { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.unknownStatus", statusKey)) + } + return phase +} + +type BlueGreenRole struct { + name string + value int +} + +func (b BlueGreenRole) GetName() string { + return b.name +} + +func (b BlueGreenRole) GetValue() int { + return b.value +} + +func ParseRole(roleKey string) BlueGreenRole { + role, ok := blueGreenRoleMapping[strings.ToUpper(roleKey)] + if !ok { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.unknownRole", roleKey)) + } + return role +} + +func (b BlueGreenRole) IsZero() bool { + return b.name == "" && b.value == 0 +} + +func (b BlueGreenRole) String() string { + return fmt.Sprintf("BlueGreenRole [name: %s, value: %d]", b.name, b.value) +} + +type BlueGreenStatus struct { + bgId string + currentPhase BlueGreenPhase + connectRoutings []ConnectRouting + executeRoutings []ExecuteRouting + roleByHost *utils.RWMap[BlueGreenRole] + correspondingHosts *utils.RWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]] +} + +func NewBgStatus(id string, phase BlueGreenPhase, connectRoutings []ConnectRouting, executeRoutings []ExecuteRouting, + roleByHost *utils.RWMap[BlueGreenRole], correspondingHosts *utils.RWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]]) BlueGreenStatus { + return BlueGreenStatus{ + bgId: id, + currentPhase: phase, + connectRoutings: connectRoutings, + executeRoutings: executeRoutings, + roleByHost: roleByHost, + correspondingHosts: correspondingHosts, + } +} + +func (b BlueGreenStatus) GetCurrentPhase() BlueGreenPhase { + return b.currentPhase +} + +func (b BlueGreenStatus) GetConnectRoutings() []ConnectRouting { + return b.connectRoutings +} + +func (b BlueGreenStatus) GetExecuteRoutings() []ExecuteRouting { + return b.executeRoutings +} + +func (b BlueGreenStatus) GetBgId() string { + return b.bgId +} + +func (b BlueGreenStatus) GetCorrespondingHosts() map[string]utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo] { + if b.correspondingHosts == nil { + return nil + } + return b.correspondingHosts.GetAllEntries() +} + +func (b BlueGreenStatus) GetRole(hostInfo *host_info_util.HostInfo) (role BlueGreenRole, ok bool) { + if hostInfo.IsNil() || b.roleByHost == nil { + return + } + return b.roleByHost.Get(strings.ToLower(hostInfo.GetHost())) +} + +func (b BlueGreenStatus) IsZero() bool { + return b.bgId == "" && + b.currentPhase.IsZero() && + b.connectRoutings == nil && + b.executeRoutings == nil && + b.roleByHost == nil && + b.correspondingHosts == nil +} + +func (b BlueGreenStatus) MatchIdPhaseAndLen(other BlueGreenStatus) bool { + return b.bgId == other.bgId && + b.currentPhase == other.currentPhase && + len(b.connectRoutings) == len(other.connectRoutings) && + len(b.executeRoutings) == len(other.executeRoutings) && + b.roleByHost.Size() == other.roleByHost.Size() && + b.correspondingHosts.Size() == other.correspondingHosts.Size() +} + +func (b BlueGreenStatus) String() string { + roleByHostMapStr := "-" + connectRoutingStr := "-" + executeRoutingStr := "-" + + return fmt.Sprintf("BlueGreenStatus [\n"+ + "\tbgId: %s,\n"+ + "\tphase: %s,\n"+ + "\tconnect routing: %s,\n"+ + "\texecute routing: %s,\n"+ + "\troleByHost: %s\n"+ + "]", + b.bgId, b.currentPhase.GetName(), connectRoutingStr, executeRoutingStr, roleByHostMapStr) +} + +type ConnectRouting interface { + IsMatch(hostInfo *host_info_util.HostInfo, hostRole BlueGreenRole) bool + + Apply( + plugin ConnectionPlugin, + hostInfo *host_info_util.HostInfo, + properties map[string]string, + isInitialConnection bool, + pluginService PluginService, + ) (driver.Conn, error) +} + +type ExecuteRouting interface { + IsMatch(hostInfo *host_info_util.HostInfo, hostRole BlueGreenRole) bool + Apply( + plugin ConnectionPlugin, + properties map[string]string, + pluginService PluginService, + methodName string, + methodFunc ExecuteFunc, + methodArgs ...any, + ) RoutingResultHolder +} + +type BlueGreenResult struct { + Version string + Endpoint string + Port int + Role string + Status string +} + +func (b *BlueGreenResult) String() string { + return fmt.Sprintf("BlueGreenResult [\n"+ + "\tversion: %s,\n"+ + "\tendpoint: %s,\n"+ + "\tport routing: %d,\n"+ + "\trole routing: %s,\n"+ + "\tstatus: %s\n"+ + "]", + b.Version, b.Endpoint, b.Port, b.Role, b.Status) +} + +type EmptyResult struct{} + +type RoutingResultHolder struct { + WrappedReturnValue any + WrappedReturnValue2 any + WrappedOk bool + WrappedErr error +} + +var EMPTY_VAL = EmptyResult{} +var EMPTY_ROUTING_RESULT_HOLDER = RoutingResultHolder{WrappedReturnValue: EMPTY_VAL} + +func (r RoutingResultHolder) GetResult() (any, any, bool, error) { + return r.WrappedReturnValue, r.WrappedReturnValue2, r.WrappedOk, r.WrappedErr +} + +func (r RoutingResultHolder) IsPresent() bool { + return r != EMPTY_ROUTING_RESULT_HOLDER +} diff --git a/awssql/driver_infrastructure/cluster_topology_monitor.go b/awssql/driver_infrastructure/cluster_topology_monitor.go index 6c3d65f4..bce9d7f9 100644 --- a/awssql/driver_infrastructure/cluster_topology_monitor.go +++ b/awssql/driver_infrastructure/cluster_topology_monitor.go @@ -238,28 +238,29 @@ func (c *ClusterTopologyMonitorImpl) openAnyConnectionAndUpdateTopology() ([]*ho if c.loadConn(c.monitoringConn) == nil { // Open a new connection. conn, err := c.pluginService.ForceConnect(c.initialHostInfo, utils.CreateMapCopy(c.monitoringProps)) - if err != nil { + if err != nil || conn == nil { // Can't connect. return nil, err } if c.monitoringConn.CompareAndSwap(emptyContainer, ConnectionContainer{conn}) { - slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.openedMonitoringConnection", c.initialHostInfo.Host)) + slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.openedMonitoringConnection", c.initialHostInfo.GetHost())) writerId, getWriterNameErr := c.databaseDialect.GetWriterHostName(conn) if getWriterNameErr == nil && writerId != "" { c.isVerifiedWriterConn = true writerVerifiedByThisRoutine = true - if utils.IsRdsDns(c.initialHostInfo.Host) { + if utils.IsRdsInstance(c.initialHostInfo.GetHost()) { c.writerHostInfo.Store(c.initialHostInfo) + slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.writerMonitoringConnection", c.writerHostInfo.Load().GetHost())) } else { hostId := c.databaseDialect.GetHostName(c.loadConn(c.monitoringConn)) if hostId != "" { c.writerHostInfo.Store(c.createHost(hostId, true, 0, time.Time{})) + slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.writerMonitoringConnection", c.writerHostInfo.Load().GetHost())) } } - slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.writerMonitoringConnection", c.writerHostInfo.Load().Host)) } } else { // Monitoring connection has already been set by other routine, close new connection as we don't need it. @@ -352,7 +353,7 @@ func (c *ClusterTopologyMonitorImpl) createHost(hostName string, isWriter bool, endpoint := c.getHostEndpoint(hostName) port := c.clusterInstanceTemplate.Port if port == host_info_util.HOST_NO_PORT { - if c.initialHostInfo.Port != host_info_util.HOST_NO_PORT { + if c.initialHostInfo.IsPortSpecified() { port = c.initialHostInfo.Port } else { port = c.hostListProvider.databaseDialect.GetDefaultPort() @@ -397,7 +398,7 @@ func (c *ClusterTopologyMonitorImpl) notifyChannel(channel chan bool) { } func (c *ClusterTopologyMonitorImpl) Run(wg *sync.WaitGroup) { - slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.startMonitoringRoutine", c.initialHostInfo.Host)) + slog.Debug(error_util.GetMessage("ClusterTopologyMonitorImpl.startMonitoringRoutine", c.initialHostInfo.GetHost())) for !c.stop.Load() { if c.isInPanicMode() { diff --git a/awssql/driver_infrastructure/connection_plugin.go b/awssql/driver_infrastructure/connection_plugin.go index e6e79a6e..8793c928 100644 --- a/awssql/driver_infrastructure/connection_plugin.go +++ b/awssql/driver_infrastructure/connection_plugin.go @@ -18,6 +18,7 @@ package driver_infrastructure import ( "database/sql/driver" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" ) @@ -35,5 +36,6 @@ type ConnectionPlugin interface { GetHostSelectorStrategy(strategy string) (HostSelector, error) NotifyConnectionChanged(changes map[HostChangeOptions]bool) OldConnectionSuggestedAction NotifyHostListChanged(changes map[string]map[HostChangeOptions]bool) - InitHostProvider(initialUrl string, props map[string]string, hostListProviderService HostListProviderService, initHostProviderFunc func() error) error + InitHostProvider(props map[string]string, hostListProviderService HostListProviderService, initHostProviderFunc func() error) error + GetPluginCode() string } diff --git a/awssql/driver_infrastructure/database_dialect.go b/awssql/driver_infrastructure/database_dialect.go index 24e900a8..91192f90 100644 --- a/awssql/driver_infrastructure/database_dialect.go +++ b/awssql/driver_infrastructure/database_dialect.go @@ -27,7 +27,7 @@ type DatabaseDialect interface { GetServerVersionQuery() string GetDialectUpdateCandidates() []string IsDialect(conn driver.Conn) bool - GetHostListProvider(props map[string]string, initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider + GetHostListProvider(props map[string]string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider DoesStatementSetAutoCommit(statement string) (bool, bool) DoesStatementSetReadOnly(statement string) (bool, bool) DoesStatementSetCatalog(statement string) (string, bool) @@ -52,3 +52,10 @@ type AuroraLimitlessDialect interface { GetLimitlessRouterEndpointQuery() string DatabaseDialect } + +type BlueGreenDialect interface { + GetBlueGreenStatus(conn driver.Conn) []BlueGreenResult + IsBlueGreenStatusAvailable(conn driver.Conn) bool + DatabaseDialect +} + diff --git a/awssql/driver_infrastructure/dsn_host_list_provider.go b/awssql/driver_infrastructure/dsn_host_list_provider.go index bf28d259..1de4813b 100644 --- a/awssql/driver_infrastructure/dsn_host_list_provider.go +++ b/awssql/driver_infrastructure/dsn_host_list_provider.go @@ -31,19 +31,19 @@ import ( type DsnHostListProvider struct { isSingleWriterConnectionString bool - dsn string + props map[string]string hostListProviderService HostListProviderService isInitialized bool hostList []*host_info_util.HostInfo initialHost string } -func NewDsnHostListProvider(props map[string]string, dsn string, hostListProviderService HostListProviderService) *DsnHostListProvider { +func NewDsnHostListProvider(props map[string]string, hostListProviderService HostListProviderService) *DsnHostListProvider { isSingleWriterConnectionString := property_util.GetVerifiedWrapperPropertyValue[bool](props, property_util.SINGLE_WRITER_DSN) initialHost := property_util.GetVerifiedWrapperPropertyValue[string](props, property_util.HOST) return &DsnHostListProvider{ isSingleWriterConnectionString, - dsn, + props, hostListProviderService, false, []*host_info_util.HostInfo{}, @@ -56,7 +56,7 @@ func (c *DsnHostListProvider) init() error { return nil } - hosts, err := utils.GetHostsFromDsn(c.dsn, c.isSingleWriterConnectionString) + hosts, err := utils.GetHostsFromProps(c.props, c.isSingleWriterConnectionString) if err != nil { return err } @@ -102,7 +102,10 @@ func (c *DsnHostListProvider) CreateHost(hostName string, hostRole host_info_uti builder := host_info_util.NewHostInfoBuilder() weight := int(math.Round(lag)*100 + math.Round(cpu)) port := c.hostListProviderService.GetDialect().GetDefaultPort() - builder.SetHost(c.initialHost).SetPort(port).SetRole(hostRole).SetAvailability(host_info_util.AVAILABLE).SetWeight(weight).SetLastUpdateTime(lastUpdateTime) + if hostName == "" { + hostName = c.initialHost + } + builder.SetHost(hostName).SetPort(port).SetRole(hostRole).SetAvailability(host_info_util.AVAILABLE).SetWeight(weight).SetLastUpdateTime(lastUpdateTime) hostInfo, _ := builder.Build() return hostInfo } diff --git a/awssql/driver_infrastructure/fixed_value_types.go b/awssql/driver_infrastructure/fixed_value_types.go index 887aadb1..a3e8f314 100644 --- a/awssql/driver_infrastructure/fixed_value_types.go +++ b/awssql/driver_infrastructure/fixed_value_types.go @@ -16,6 +16,19 @@ package driver_infrastructure +const ( + BLUE_GREEN_PLUGIN_CODE string = "bg" + READ_WRITE_SPLITTING_PLUGIN_CODE string = "readWriteSplitting" + FAILOVER_PLUGIN_CODE string = "failover" + EFM_PLUGIN_CODE string = "efm" + LIMITLESS_PLUGIN_CODE string = "limitless" + IAM_PLUGIN_CODE string = "iam" + SECRETS_MANAGER_PLUGIN_CODE string = "awsSecretsManager" + ADFS_PLUGIN_CODE string = "federatedAuth" + OKTA_PLUGIN_CODE string = "okta" + EXECUTION_TIME_PLUGIN_CODE string = "executionTime" +) + type HostChangeOptions int const ( diff --git a/awssql/driver_infrastructure/host_list_provider.go b/awssql/driver_infrastructure/host_list_provider.go index a5541aa2..3e97b04b 100644 --- a/awssql/driver_infrastructure/host_list_provider.go +++ b/awssql/driver_infrastructure/host_list_provider.go @@ -18,8 +18,9 @@ package driver_infrastructure import ( "database/sql/driver" - "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" ) type HostListProvider interface { diff --git a/awssql/driver_infrastructure/monitoring_rds_host_list_provider.go b/awssql/driver_infrastructure/monitoring_rds_host_list_provider.go index 83a7e4c5..7dd83496 100644 --- a/awssql/driver_infrastructure/monitoring_rds_host_list_provider.go +++ b/awssql/driver_infrastructure/monitoring_rds_host_list_provider.go @@ -51,7 +51,6 @@ func NewMonitoringRdsHostListProvider( hostListProviderService HostListProviderService, databaseDialect TopologyAwareDialect, properties map[string]string, - originalDsn string, pluginService PluginService) *MonitoringRdsHostListProvider { clusterTopologyMonitorsMutex.Lock() if clusterTopologyMonitors == nil { @@ -88,7 +87,7 @@ func NewMonitoringRdsHostListProvider( TopologyCache.Put(m.clusterId, existingHosts, TOPOLOGY_CACHE_EXPIRATION_NANO) } } - m.RdsHostListProvider = NewRdsHostListProvider(hostListProviderService, databaseDialect, properties, originalDsn, queryForTopologyFunc, clusterIdChangedFunc) + m.RdsHostListProvider = NewRdsHostListProvider(hostListProviderService, databaseDialect, properties, queryForTopologyFunc, clusterIdChangedFunc) return m } diff --git a/awssql/driver_infrastructure/mysql_database_dialects.go b/awssql/driver_infrastructure/mysql_database_dialects.go index 47dffe18..fabab5a2 100644 --- a/awssql/driver_infrastructure/mysql_database_dialects.go +++ b/awssql/driver_infrastructure/mysql_database_dialects.go @@ -67,10 +67,9 @@ func (m *MySQLDatabaseDialect) IsDialect(conn driver.Conn) bool { func (m *MySQLDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return HostListProvider(NewDsnHostListProvider(props, initialDsn, hostListProviderService)) + return NewDsnHostListProvider(props, hostListProviderService) } func (m *MySQLDatabaseDialect) GetSetAutoCommitQuery(autoCommit bool) (string, error) { @@ -193,6 +192,16 @@ func (m *RdsMySQLDatabaseDialect) IsDialect(conn driver.Conn) bool { return false } +func (m *RdsMySQLDatabaseDialect) GetBlueGreenStatus(conn driver.Conn) []BlueGreenResult { + bgStatusQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + return mySqlGetBlueGreenStatus(conn, bgStatusQuery) +} + +func (m *RdsMySQLDatabaseDialect) IsBlueGreenStatusAvailable(conn driver.Conn) bool { + topologyTableExistQuery := "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + return utils.GetFirstRowFromQuery(conn, topologyTableExistQuery) != nil +} + type MySQLTopologyAwareDatabaseDialect struct { MySQLDatabaseDialect } @@ -227,27 +236,25 @@ func (m *MySQLTopologyAwareDatabaseDialect) GetWriterHostName(conn driver.Conn) func (m *MySQLTopologyAwareDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return m.getTopologyAwareHostListProvider(m, props, initialDsn, hostListProviderService, pluginService) + return m.getTopologyAwareHostListProvider(m, props, hostListProviderService, pluginService) } func (m *MySQLTopologyAwareDatabaseDialect) getTopologyAwareHostListProvider( dialect TopologyAwareDialect, props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { pluginsProp := property_util.GetVerifiedWrapperPropertyValue[string](props, property_util.PLUGINS) if strings.Contains(pluginsProp, "failover") { slog.Debug(error_util.GetMessage("DatabaseDialect.usingMonitoringHostListProvider")) - return HostListProvider(NewMonitoringRdsHostListProvider(hostListProviderService, dialect, props, initialDsn, pluginService)) + return NewMonitoringRdsHostListProvider(hostListProviderService, dialect, props, pluginService) } slog.Debug(error_util.GetMessage("DatabaseDialect.usingRdsHostListProvider")) - return HostListProvider(NewRdsHostListProvider(hostListProviderService, dialect, props, initialDsn, nil, nil)) + return NewRdsHostListProvider(hostListProviderService, dialect, props, nil, nil) } type AuroraMySQLDatabaseDialect struct { @@ -289,10 +296,9 @@ func (m *AuroraMySQLDatabaseDialect) GetWriterHostName(conn driver.Conn) (string func (m *AuroraMySQLDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return m.getTopologyAwareHostListProvider(m, props, initialDsn, hostListProviderService, pluginService) + return m.getTopologyAwareHostListProvider(m, props, hostListProviderService, pluginService) } func (m *AuroraMySQLDatabaseDialect) GetTopology(conn driver.Conn, provider HostListProvider) ([]*host_info_util.HostInfo, error) { @@ -317,7 +323,7 @@ func (m *AuroraMySQLDatabaseDialect) GetTopology(conn driver.Conn, provider Host defer rows.Close() } - hosts := []*host_info_util.HostInfo{} + var hosts []*host_info_util.HostInfo if rows == nil { // Query returned an empty host list, no processing required. return hosts, nil @@ -355,6 +361,16 @@ func (m *AuroraMySQLDatabaseDialect) GetTopology(conn driver.Conn, provider Host return hosts, nil } +func (m *AuroraMySQLDatabaseDialect) GetBlueGreenStatus(conn driver.Conn) []BlueGreenResult { + bgStatusQuery := "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" + return mySqlGetBlueGreenStatus(conn, bgStatusQuery) +} + +func (m *AuroraMySQLDatabaseDialect) IsBlueGreenStatusAvailable(conn driver.Conn) bool { + topologyTableExistQuery := "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + return utils.GetFirstRowFromQuery(conn, topologyTableExistQuery) != nil +} + type RdsMultiAzClusterMySQLDatabaseDialect struct { MySQLTopologyAwareDatabaseDialect } @@ -422,7 +438,7 @@ func (r *RdsMultiAzClusterMySQLDatabaseDialect) processTopologyQueryResults( provider HostListProvider, writerHostId string, rows driver.Rows) []*host_info_util.HostInfo { - hosts := []*host_info_util.HostInfo{} + var hosts []*host_info_util.HostInfo row := make([]driver.Value, len(rows.Columns())) err := rows.Next(row) for err == nil && len(row) > 1 { @@ -483,7 +499,7 @@ func (r *RdsMultiAzClusterMySQLDatabaseDialect) getWriterHostId(conn driver.Conn return r.getHostIdOfCurrentConnection(conn) } - var sourceIndex int = -1 + var sourceIndex = -1 for i, name := range columnNames { if name == "Source_Server_Id" { sourceIndex = i @@ -537,8 +553,55 @@ func (r *RdsMultiAzClusterMySQLDatabaseDialect) getHostIdOfCurrentConnection(con func (r *RdsMultiAzClusterMySQLDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return r.getTopologyAwareHostListProvider(r, props, initialDsn, hostListProviderService, pluginService) + return r.getTopologyAwareHostListProvider(r, props, hostListProviderService, pluginService) +} + +func mySqlGetBlueGreenStatus(conn driver.Conn, query string) []BlueGreenResult { + return getBlueGreenStatus(conn, query, utils.MySqlConvertValToString) +} + +func getBlueGreenStatus(conn driver.Conn, query string, convertFunc func(driver.Value) (string, bool)) []BlueGreenResult { + queryerCtx, ok := conn.(driver.QueryerContext) + if !ok { + // Unable to query, conn does not implement QueryerContext. + slog.Warn(error_util.GetMessage("Conn.doesNotImplementRequiredInterface", "driver.QueryerContext")) + return nil + } + + rows, err := queryerCtx.QueryContext(context.Background(), query, nil) + if err != nil { + // Query failed. + slog.Warn(error_util.GetMessage("BlueGreenDeployment.errorQueryingStatusTable", err)) + return nil + } + if rows != nil { + defer rows.Close() + } + + var statuses []BlueGreenResult + row := make([]driver.Value, len(rows.Columns())) + for rows.Next(row) == nil { + if len(row) > 4 { + version, ok1 := convertFunc(row[0]) + endpoint, ok2 := convertFunc(row[1]) + portAsFloat, ok3 := row[2].(int64) + role, ok4 := convertFunc(row[3]) + status, ok5 := convertFunc(row[4]) + + if !ok1 || !ok2 || !ok3 || !ok4 || !ok5 { + continue + } + statuses = append(statuses, BlueGreenResult{ + Version: version, + Endpoint: endpoint, + Port: int(portAsFloat), + Role: role, + Status: status, + }) + } + } + + return statuses } diff --git a/awssql/driver_infrastructure/pg_database_dialects.go b/awssql/driver_infrastructure/pg_database_dialects.go index f4dfd5e1..713c9420 100644 --- a/awssql/driver_infrastructure/pg_database_dialects.go +++ b/awssql/driver_infrastructure/pg_database_dialects.go @@ -59,10 +59,9 @@ func (p *PgDatabaseDialect) IsDialect(conn driver.Conn) bool { func (p *PgDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return HostListProvider(NewDsnHostListProvider(props, initialDsn, hostListProviderService)) + return NewDsnHostListProvider(props, hostListProviderService) } func (p *PgDatabaseDialect) DoesStatementSetAutoCommit(statement string) (bool, bool) { @@ -177,6 +176,16 @@ func (m *RdsPgDatabaseDialect) IsDialect(conn driver.Conn) bool { hasExtensions[1] == false // If aurora_stat_utils is present then it should be treated as an Aurora cluster, not an RDS cluster. } +func (m *RdsPgDatabaseDialect) GetBlueGreenStatus(conn driver.Conn) []BlueGreenResult { + bgStatusQuery := "SELECT version, endpoint, port, role, status FROM rds_tools.show_topology('aws_advanced_go_wrapper-" + driver_info.AWS_ADVANCED_GO_WRAPPER_VERSION + "')" + return pgGetBlueGreenStatus(conn, bgStatusQuery) +} + +func (m *RdsPgDatabaseDialect) IsBlueGreenStatusAvailable(conn driver.Conn) bool { + topologyTableExistQuery := "SELECT 'rds_tools.show_topology'::regproc" + return utils.GetFirstRowFromQuery(conn, topologyTableExistQuery) != nil +} + type PgTopologyAwareDatabaseDialect struct { PgDatabaseDialect } @@ -210,27 +219,25 @@ func (m *PgTopologyAwareDatabaseDialect) GetWriterHostName(conn driver.Conn) (st func (m *PgTopologyAwareDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return m.getTopologyAwareHostListProvider(m, props, initialDsn, hostListProviderService, pluginService) + return m.getTopologyAwareHostListProvider(m, props, hostListProviderService, pluginService) } func (m *PgTopologyAwareDatabaseDialect) getTopologyAwareHostListProvider( dialect TopologyAwareDialect, props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { pluginsProp := property_util.GetVerifiedWrapperPropertyValue[string](props, property_util.PLUGINS) if strings.Contains(pluginsProp, "failover") { slog.Debug(error_util.GetMessage("DatabaseDialect.usingMonitoringHostListProvider")) - return HostListProvider(NewMonitoringRdsHostListProvider(hostListProviderService, dialect, props, initialDsn, pluginService)) + return NewMonitoringRdsHostListProvider(hostListProviderService, dialect, props, pluginService) } slog.Debug(error_util.GetMessage("DatabaseDialect.usingRdsHostListProvider")) - return HostListProvider(NewRdsHostListProvider(hostListProviderService, dialect, props, initialDsn, nil, nil)) + return NewRdsHostListProvider(hostListProviderService, dialect, props, nil, nil) } type AuroraPgDatabaseDialect struct { @@ -276,7 +283,7 @@ func (m *AuroraPgDatabaseDialect) GetTopology(conn driver.Conn, provider HostLis defer rows.Close() } - hosts := []*host_info_util.HostInfo{} + var hosts []*host_info_util.HostInfo if rows == nil { // Query returned an empty host list, no processing required. return hosts, nil @@ -339,16 +346,26 @@ func (m *AuroraPgDatabaseDialect) GetWriterHostName(conn driver.Conn) (string, e func (m *AuroraPgDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return m.getTopologyAwareHostListProvider(m, props, initialDsn, hostListProviderService, pluginService) + return m.getTopologyAwareHostListProvider(m, props, hostListProviderService, pluginService) } func (m *AuroraPgDatabaseDialect) GetLimitlessRouterEndpointQuery() string { return "select router_endpoint, load from aurora_limitless_router_endpoints()" } +func (m *AuroraPgDatabaseDialect) GetBlueGreenStatus(conn driver.Conn) []BlueGreenResult { + bgStatusQuery := "SELECT version, endpoint, port, role, status FROM get_blue_green_fast_switchover_metadata(" + + "'aws_advanced_go_wrapper-" + driver_info.AWS_ADVANCED_GO_WRAPPER_VERSION + "')" + return pgGetBlueGreenStatus(conn, bgStatusQuery) +} + +func (m *AuroraPgDatabaseDialect) IsBlueGreenStatusAvailable(conn driver.Conn) bool { + topologyTableExistQuery := "SELECT 'get_blue_green_fast_switchover_metadata'::regproc" + return utils.GetFirstRowFromQuery(conn, topologyTableExistQuery) != nil +} + type RdsMultiAzClusterPgDatabaseDialect struct { PgTopologyAwareDatabaseDialect } @@ -401,7 +418,7 @@ func (r *RdsMultiAzClusterPgDatabaseDialect) processTopologyQueryResults( provider HostListProvider, writerHostId string, rows driver.Rows) []*host_info_util.HostInfo { - hosts := []*host_info_util.HostInfo{} + var hosts []*host_info_util.HostInfo row := make([]driver.Value, len(rows.Columns())) err := rows.Next(row) for err == nil && len(row) > 1 { @@ -491,8 +508,11 @@ func (r *RdsMultiAzClusterPgDatabaseDialect) GetWriterHostName(conn driver.Conn) func (r *RdsMultiAzClusterPgDatabaseDialect) GetHostListProvider( props map[string]string, - initialDsn string, hostListProviderService HostListProviderService, pluginService PluginService) HostListProvider { - return r.getTopologyAwareHostListProvider(r, props, initialDsn, hostListProviderService, pluginService) + return r.getTopologyAwareHostListProvider(r, props, hostListProviderService, pluginService) +} + +func pgGetBlueGreenStatus(conn driver.Conn, query string) []BlueGreenResult { + return getBlueGreenStatus(conn, query, utils.PgConvertValToString) } diff --git a/awssql/driver_infrastructure/plugin_helpers.go b/awssql/driver_infrastructure/plugin_helpers.go index 99df81bc..1003491e 100644 --- a/awssql/driver_infrastructure/plugin_helpers.go +++ b/awssql/driver_infrastructure/plugin_helpers.go @@ -31,7 +31,7 @@ type PluginConnectFunc func(plugin ConnectionPlugin, props map[string]string, ta type HostListProviderService interface { IsStaticHostListProvider() bool - CreateHostListProvider(props map[string]string, dsn string) HostListProvider + CreateHostListProvider(props map[string]string) HostListProvider GetHostListProvider() HostListProvider SetHostListProvider(hostListProvider HostListProvider) SetInitialConnectionHostInfo(info *host_info_util.HostInfo) @@ -55,7 +55,7 @@ type PluginService interface { SetInTransaction(inTransaction bool) GetCurrentTx() driver.Tx SetCurrentTx(driver.Tx) - CreateHostListProvider(props map[string]string, dsn string) HostListProvider + CreateHostListProvider(props map[string]string) HostListProvider SetHostListProvider(hostListProvider HostListProvider) SetInitialConnectionHostInfo(info *host_info_util.HostInfo) IsStaticHostListProvider() bool @@ -81,6 +81,9 @@ type PluginService interface { SetTelemetryContext(ctx context.Context) UpdateState(sql string, methodArgs ...any) IsReadOnly() bool + GetBgStatus(id string) (BlueGreenStatus, bool) + SetBgStatus(status BlueGreenStatus, id string) + IsPluginInUse(pluginName string) bool } type PluginServiceProvider func( @@ -91,7 +94,7 @@ type PluginServiceProvider func( type PluginManager interface { Init(pluginService PluginService, plugins []ConnectionPlugin) error - InitHostProvider(initialUrl string, props map[string]string, hostListProviderService HostListProviderService) error + InitHostProvider(props map[string]string, hostListProviderService HostListProviderService) error Connect(hostInfo *host_info_util.HostInfo, props map[string]string, isInitialConnection bool, pluginToSkip ConnectionPlugin) (driver.Conn, error) ForceConnect(hostInfo *host_info_util.HostInfo, props map[string]string, isInitialConnection bool) (driver.Conn, error) Execute(connInvokedOn driver.Conn, name string, methodFunc ExecuteFunc, methodArgs ...any) ( @@ -112,7 +115,9 @@ type PluginManager interface { GetTelemetryContext() context.Context GetTelemetryFactory() telemetry.TelemetryFactory SetTelemetryContext(ctx context.Context) + IsPluginInUse(pluginName string) bool ReleaseResources() + UnwrapPlugin(pluginCode string) ConnectionPlugin } type PluginManagerProvider func( @@ -125,7 +130,7 @@ type CanReleaseResources interface { ReleaseResources() } -// This cleans up all long standing caches. To be called at the end of program, not each time a Conn is closed. +// This cleans up all long-standing caches. To be called at the end of program, not each time a Conn is closed. func ClearCaches() { if knownEndpointDialectsCache != nil { knownEndpointDialectsCache.Clear() diff --git a/awssql/driver_infrastructure/rds_host_list_provider.go b/awssql/driver_infrastructure/rds_host_list_provider.go index 1120a891..e5e63b34 100644 --- a/awssql/driver_infrastructure/rds_host_list_provider.go +++ b/awssql/driver_infrastructure/rds_host_list_provider.go @@ -37,14 +37,12 @@ func NewRdsHostListProvider( hostListProviderService HostListProviderService, databaseDialect TopologyAwareDialect, properties map[string]string, - originalDsn string, queryForTopologyFunc func(conn driver.Conn) ([]*host_info_util.HostInfo, error), clusterIdChangedFunc func(oldClusterId string)) *RdsHostListProvider { r := &RdsHostListProvider{ hostListProviderService: hostListProviderService, databaseDialect: databaseDialect, properties: properties, - originalDsn: originalDsn, isInitialized: false, } if queryForTopologyFunc == nil { @@ -65,7 +63,6 @@ type RdsHostListProvider struct { hostListProviderService HostListProviderService databaseDialect TopologyAwareDialect properties map[string]string - originalDsn string isInitialized bool // The following properties are initialized from the above in init(). initialHostList []*host_info_util.HostInfo @@ -87,7 +84,7 @@ func (r *RdsHostListProvider) init() { } refreshRateInt := property_util.GetRefreshRateValue(r.properties, property_util.CLUSTER_TOPOLOGY_REFRESH_RATE_MS) r.refreshRateNanos = time.Millisecond * time.Duration(refreshRateInt) - hostListFromDsn, err := utils.GetHostsFromDsn(r.originalDsn, false) + hostListFromDsn, err := utils.GetHostsFromProps(r.properties, false) if err != nil || len(hostListFromDsn) == 0 { return } @@ -158,7 +155,7 @@ func (r *RdsHostListProvider) ForceRefresh(conn driver.Conn) ([]*host_info_util. if err != nil { return nil, err } - slog.Info(utils.LogTopology(hosts, "From ForceRefresh")) + slog.Debug(utils.LogTopology(hosts, "From ForceRefresh")) return hosts, nil } diff --git a/awssql/host_info_util/host_info.go b/awssql/host_info_util/host_info.go index e417da56..ddd22b8e 100644 --- a/awssql/host_info_util/host_info.go +++ b/awssql/host_info_util/host_info.go @@ -57,8 +57,10 @@ type HostInfo struct { } func (hostInfo *HostInfo) AddAlias(alias string) { - hostInfo.Aliases[alias] = true - hostInfo.AllAliases[alias] = true + if alias != "" && hostInfo != nil { + hostInfo.Aliases[alias] = true + hostInfo.AllAliases[alias] = true + } } func (hostInfo *HostInfo) ResetAliases() { @@ -72,7 +74,17 @@ func (hostInfo *HostInfo) GetUrl() string { return hostInfo.GetHostAndPort() + "/" } +func (hostInfo *HostInfo) GetHost() string { + if hostInfo == nil { + return "" + } + return hostInfo.Host +} + func (hostInfo *HostInfo) GetHostAndPort() string { + if hostInfo == nil { + return "" + } if hostInfo.IsPortSpecified() { return hostInfo.Host + ":" + strconv.Itoa(hostInfo.Port) } @@ -80,7 +92,7 @@ func (hostInfo *HostInfo) GetHostAndPort() string { } func (hostInfo *HostInfo) IsPortSpecified() bool { - return hostInfo.Port != HOST_NO_PORT + return hostInfo != nil && hostInfo.Port != HOST_NO_PORT } func (hostInfo *HostInfo) Equals(host *HostInfo) bool { @@ -177,6 +189,19 @@ func (hostInfoBuilder *HostInfoBuilder) SetLastUpdateTime(lastUpdateTime time.Ti return hostInfoBuilder } +func (hostInfoBuilder *HostInfoBuilder) CopyFrom(hostInfo *HostInfo) *HostInfoBuilder { + if hostInfo != nil { + hostInfoBuilder.host = hostInfo.Host + hostInfoBuilder.hostId = hostInfo.HostId + hostInfoBuilder.port = hostInfo.Port + hostInfoBuilder.availability = hostInfo.Availability + hostInfoBuilder.role = hostInfo.Role + hostInfoBuilder.weight = hostInfo.Weight + hostInfoBuilder.lastUpdateTime = hostInfo.LastUpdateTime + } + return hostInfoBuilder +} + func (hostInfoBuilder *HostInfoBuilder) Build() (hostInfo *HostInfo, err error) { err = hostInfoBuilder.checkHostIsSet() if err != nil { diff --git a/awssql/host_info_util/host_info_util.go b/awssql/host_info_util/host_info_util.go index 29db2231..5065c8cf 100644 --- a/awssql/host_info_util/host_info_util.go +++ b/awssql/host_info_util/host_info_util.go @@ -16,6 +16,12 @@ package host_info_util +import ( + "fmt" + "slices" + "strings" +) + func AreHostListsEqual(s1 []*HostInfo, s2 []*HostInfo) bool { if len(s1) != len(s2) { return false @@ -32,9 +38,50 @@ func AreHostListsEqual(s1 []*HostInfo, s2 []*HostInfo) bool { func GetWriter(hosts []*HostInfo) *HostInfo { for _, host := range hosts { - if host.Role == WRITER { + if host != nil && host.Role == WRITER { return host } } return nil } + +func GetReaders(hosts []*HostInfo) []*HostInfo { + readerHosts := make([]*HostInfo, 0, len(hosts)) + for _, host := range hosts { + if host != nil && host.Role == READER { + readerHosts = append(readerHosts, host) + } + } + slices.SortFunc(readerHosts, func(i, j *HostInfo) int { + return strings.Compare(i.GetHost(), j.GetHost()) + }) + return readerHosts +} + +func GetHostAndPort(host string, port int) string { + if port > 0 && host != "" { + return fmt.Sprintf("%s:%d", host, port) + } + return host +} + +func HaveNoHostsInCommon(hosts1 []*HostInfo, hosts2 []*HostInfo) bool { + var mapSlice, checkSlice []*HostInfo + if len(hosts1) <= len(hosts2) { + mapSlice, checkSlice = hosts1, hosts2 + } else { + mapSlice, checkSlice = hosts2, hosts1 + } + + checkMap := make(map[string]int, len(mapSlice)) + for _, host := range mapSlice { + checkMap[host.Host] = 0 + } + + for _, host := range checkSlice { + if _, exists := checkMap[host.Host]; exists { + return false + } + } + return true +} diff --git a/awssql/plugin_helpers/plugin_manager.go b/awssql/plugin_helpers/plugin_manager.go index 3669a0e5..084aa7ce 100644 --- a/awssql/plugin_helpers/plugin_manager.go +++ b/awssql/plugin_helpers/plugin_manager.go @@ -109,7 +109,7 @@ type PluginManagerImpl struct { pluginService driver_infrastructure.PluginService connProviderManager driver_infrastructure.ConnectionProviderManager props map[string]string - pluginFuncMap map[string]PluginChain + pluginFuncMap *utils.RWMap[PluginChain] plugins []driver_infrastructure.ConnectionPlugin telemetryFactory telemetry.TelemetryFactory telemetryCtx context.Context @@ -121,7 +121,7 @@ func NewPluginManagerImpl( props map[string]string, connProviderManager driver_infrastructure.ConnectionProviderManager, telemetryFactory telemetry.TelemetryFactory) driver_infrastructure.PluginManager { - pluginFuncMap := make(map[string]PluginChain) + pluginFuncMap := utils.NewRWMap[PluginChain]() return &PluginManagerImpl{ targetDriver: targetDriver, props: props, @@ -140,7 +140,6 @@ func (pluginManager *PluginManagerImpl) Init( } func (pluginManager *PluginManagerImpl) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService) error { parentCtx := pluginManager.GetTelemetryContext() @@ -163,7 +162,7 @@ func (pluginManager *PluginManagerImpl) InitHostProvider( _, _, _, err := targetFunc() return err } - err := plugin.InitHostProvider(initialUrl, props, hostListProviderService, initFunc) + err := plugin.InitHostProvider(props, hostListProviderService, initFunc) if err != nil { return nil, nil, false, err } @@ -209,7 +208,7 @@ func (pluginManager *PluginManagerImpl) Connect( targetFunc := func(props map[string]string) (driver.Conn, error) { return nil, error_util.ShouldNotBeCalledError } - return pluginManager.connectWithSubscribedPlugins(CONNECT_METHOD, pluginFunc, targetFunc, pluginToSkip) + return pluginManager.connectWithSubscribedPlugins(CONNECT_METHOD, pluginFunc, targetFunc, props, pluginToSkip) } func (pluginManager *PluginManagerImpl) ForceConnect( @@ -225,7 +224,7 @@ func (pluginManager *PluginManagerImpl) ForceConnect( targetFunc := func(props map[string]string) (driver.Conn, error) { return nil, error_util.ShouldNotBeCalledError } - return pluginManager.connectWithSubscribedPlugins(FORCE_CONNECT_METHOD, pluginFunc, targetFunc, nil) + return pluginManager.connectWithSubscribedPlugins(FORCE_CONNECT_METHOD, pluginFunc, targetFunc, props, nil) } func (pluginManager *PluginManagerImpl) Execute( @@ -266,31 +265,23 @@ func (pluginManager *PluginManagerImpl) executeWithSubscribedPlugins( methodName string, pluginFunc driver_infrastructure.PluginExecFunc, targetFunc driver_infrastructure.ExecuteFunc) (any, any, bool, error) { - chain, ok := pluginManager.pluginFuncMap[methodName] - if !ok { - chain = pluginManager.makePluginChain(methodName, true, nil) - pluginManager.pluginFuncMap[methodName] = chain - } + chain := pluginManager.pluginFuncMap.ComputeIfAbsent(methodName, func() PluginChain { + return pluginManager.makePluginChain(methodName, true, nil) + }) return chain.Execute(pluginFunc, targetFunc) } -func (pluginManager *PluginManagerImpl) connectWithSubscribedPlugins( - methodName string, - pluginFunc driver_infrastructure.PluginConnectFunc, - targetFunc driver_infrastructure.ConnectFunc, - pluginToSkip driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { +func (pluginManager *PluginManagerImpl) connectWithSubscribedPlugins(methodName string, pluginFunc driver_infrastructure.PluginConnectFunc, + targetFunc driver_infrastructure.ConnectFunc, props map[string]string, pluginToSkip driver_infrastructure.ConnectionPlugin) (driver.Conn, error) { var chain PluginChain if pluginToSkip == nil { - ok := false - chain, ok = pluginManager.pluginFuncMap[methodName] - if !ok { - chain = pluginManager.makePluginChain(methodName, false, nil) - pluginManager.pluginFuncMap[methodName] = chain - } + chain = pluginManager.pluginFuncMap.ComputeIfAbsent(methodName, func() PluginChain { + return pluginManager.makePluginChain(methodName, false, nil) + }) } else { chain = pluginManager.makePluginChain(methodName, false, pluginToSkip) } - return chain.Connect(pluginFunc, pluginManager.props, targetFunc) + return chain.Connect(pluginFunc, props, targetFunc) } func (pluginManager *PluginManagerImpl) makePluginChain( @@ -462,3 +453,21 @@ func (pluginManager *PluginManagerImpl) SetTelemetryContext(ctx context.Context) defer pluginManager.telemetryCtxLock.Unlock() pluginManager.telemetryCtx = ctx } + +func (pluginManager *PluginManagerImpl) IsPluginInUse(pluginCode string) bool { + for _, plugin := range pluginManager.plugins { + if plugin.GetPluginCode() == pluginCode { + return true + } + } + return false +} + +func (pluginManager *PluginManagerImpl) UnwrapPlugin(pluginCode string) driver_infrastructure.ConnectionPlugin { + for _, plugin := range pluginManager.plugins { + if plugin.GetPluginCode() == pluginCode { + return plugin + } + } + return nil +} diff --git a/awssql/plugin_helpers/plugin_service.go b/awssql/plugin_helpers/plugin_service.go index 23e0c8b0..d4734100 100644 --- a/awssql/plugin_helpers/plugin_service.go +++ b/awssql/plugin_helpers/plugin_service.go @@ -19,21 +19,24 @@ package plugin_helpers import ( "context" "database/sql/driver" + "fmt" "log/slog" "slices" + "strings" "time" - "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" - "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" "github.com/aws/aws-advanced-go-wrapper/awssql/utils" "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" ) -var hostAvailabilityExpiringCache *utils.CacheMap[host_info_util.HostAvailability] = utils.NewCache[host_info_util.HostAvailability]() -var DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO time.Duration = 5 * time.Minute +var hostAvailabilityExpiringCache = utils.NewCache[host_info_util.HostAvailability]() +var statusesExpiringCache = utils.NewCache[driver_infrastructure.BlueGreenStatus]() +var DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * time.Minute +var DEFAULT_STATUS_CACHE_EXPIRE_NANO = 60 * time.Minute type PluginServiceImpl struct { pluginManager driver_infrastructure.PluginManager @@ -88,8 +91,8 @@ func (p *PluginServiceImpl) SetHostListProvider(hostListProvider driver_infrastr p.hostListProvider = hostListProvider } -func (p *PluginServiceImpl) CreateHostListProvider(props map[string]string, dsn string) driver_infrastructure.HostListProvider { - return p.GetDialect().GetHostListProvider(props, dsn, driver_infrastructure.HostListProviderService(p), p) +func (p *PluginServiceImpl) CreateHostListProvider(props map[string]string) driver_infrastructure.HostListProvider { + return p.GetDialect().GetHostListProvider(props, driver_infrastructure.HostListProviderService(p), p) } func (p *PluginServiceImpl) GetDialect() driver_infrastructure.DatabaseDialect { @@ -114,7 +117,7 @@ func (p *PluginServiceImpl) UpdateDialect(conn driver.Conn) { return } p.dialect = newDialect - p.SetHostListProvider(p.CreateHostListProvider(p.props, p.originalDsn)) + p.SetHostListProvider(p.CreateHostListProvider(p.props)) } func (p *PluginServiceImpl) GetCurrentConnection() driver.Conn { @@ -184,7 +187,7 @@ func (p *PluginServiceImpl) SetCurrentConnection( shouldCloseConnection := connectionObjectHasChanged && !p.GetTargetDriverDialect().IsClosed(oldConnection) && !preserve if shouldCloseConnection { _ = p.sessionStateService.ApplyPristineSessionState(oldConnection) - oldConnection.Close() + _ = oldConnection.Close() } } } @@ -368,17 +371,17 @@ func (p *PluginServiceImpl) updateHostListIfNeeded(updatedHostList []*host_info_ } func (p *PluginServiceImpl) setHostList(oldHosts []*host_info_util.HostInfo, newHosts []*host_info_util.HostInfo) { - var oldHostMap map[string]*host_info_util.HostInfo = map[string]*host_info_util.HostInfo{} + var oldHostMap = map[string]*host_info_util.HostInfo{} for _, host := range oldHosts { oldHostMap[host.GetHostAndPort()] = host } - var newHostMap map[string]*host_info_util.HostInfo = map[string]*host_info_util.HostInfo{} + var newHostMap = map[string]*host_info_util.HostInfo{} for _, host := range newHosts { newHostMap[host.GetHostAndPort()] = host } - var changes map[string]map[driver_infrastructure.HostChangeOptions]bool = map[string]map[driver_infrastructure.HostChangeOptions]bool{} + var changes = map[string]map[driver_infrastructure.HostChangeOptions]bool{} for hostKey, hostInfo := range oldHostMap { correspondingNewHost, ok := newHostMap[hostKey] if !ok || correspondingNewHost.IsNil() { @@ -565,7 +568,7 @@ func (p *PluginServiceImpl) UpdateState(sql string, methodArgs ...any) { func (p *PluginServiceImpl) ReleaseResources() { slog.Debug(error_util.GetMessage("PluginServiceImpl.releaseResources")) if p.currentConnection != nil { - (*p.currentConnection).Close() // Ignore any error. + _ = (*p.currentConnection).Close() // Ignore any error. p.currentConnection = nil } @@ -586,9 +589,36 @@ func (p *PluginServiceImpl) IsReadOnly() bool { return *readOnly } -// This cleans up all long standing caches. To be called at the end of program, not each time a Conn is closed. +func (p *PluginServiceImpl) GetBgStatus(id string) (driver_infrastructure.BlueGreenStatus, bool) { + return statusesExpiringCache.Get(p.getStatusCacheKey(id)) +} + +func (p *PluginServiceImpl) SetBgStatus(status driver_infrastructure.BlueGreenStatus, id string) { + cacheKey := p.getStatusCacheKey(id) + if status.IsZero() { + statusesExpiringCache.Remove(cacheKey) + } else { + statusesExpiringCache.Put(cacheKey, status, DEFAULT_STATUS_CACHE_EXPIRE_NANO) + } +} + +func (p *PluginServiceImpl) getStatusCacheKey(id string) string { + if id != "" { + id = strings.ToLower(strings.TrimSpace(id)) + } + return fmt.Sprintf("%s::%s", id, "BlueGreenStatus") +} + +func (p *PluginServiceImpl) IsPluginInUse(pluginName string) bool { + return p.pluginManager.IsPluginInUse(pluginName) +} + +// This cleans up all long-standing caches. To be called at the end of program, not each time a Conn is closed. func ClearCaches() { if hostAvailabilityExpiringCache != nil { hostAvailabilityExpiringCache.Clear() } + if statusesExpiringCache != nil { + statusesExpiringCache.Clear() + } } diff --git a/awssql/plugins/base_connection_plugin.go b/awssql/plugins/base_connection_plugin.go index d03fc7c1..bd80e5e8 100644 --- a/awssql/plugins/base_connection_plugin.go +++ b/awssql/plugins/base_connection_plugin.go @@ -19,6 +19,7 @@ package plugins import ( "database/sql/driver" "fmt" + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" @@ -27,6 +28,10 @@ import ( type BaseConnectionPlugin struct { } +func (b BaseConnectionPlugin) GetPluginCode() string { + return "" +} + func (b BaseConnectionPlugin) GetSubscribedMethods() []string { return []string{} } @@ -76,7 +81,6 @@ func (b BaseConnectionPlugin) NotifyHostListChanged(changes map[string]map[drive } func (b BaseConnectionPlugin) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { diff --git a/awssql/plugins/bg/base_routing.go b/awssql/plugins/bg/base_routing.go new file mode 100644 index 00000000..d02ae381 --- /dev/null +++ b/awssql/plugins/bg/base_routing.go @@ -0,0 +1,83 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "fmt" + "strings" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" +) + +const SLEEP_CHUNK_DURATION = 50 * time.Millisecond +const TELEMETRY_SWITCHOVER = "Blue/Green switchover" +const SLEEP_TIME_DURATION = 100 * time.Millisecond + +type BaseRouting struct { + hostAndPort string + role driver_infrastructure.BlueGreenRole +} + +func NewBaseRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole) BaseRouting { + return BaseRouting{ + hostAndPort: strings.ToLower(hostAndPort), + role: role, + } +} + +func (b BaseRouting) Delay(delayNanos time.Duration, bgStatus driver_infrastructure.BlueGreenStatus, + pluginService driver_infrastructure.PluginService, bgId string) { + endTime := time.Now().Add(delayNanos) + minDelay := min(delayNanos, SLEEP_CHUNK_DURATION) + + if bgStatus.IsZero() { + time.Sleep(delayNanos) + } else { + status, ok := pluginService.GetBgStatus(bgId) + for ok && bgStatus.MatchIdPhaseAndLen(status) && time.Now().Before(endTime) { + time.Sleep(minDelay) + } + } +} + +func (b BaseRouting) IsMatch(hostInfo *host_info_util.HostInfo, hostRole driver_infrastructure.BlueGreenRole) bool { + hostAndPort := "" + if !hostInfo.IsNil() { + hostAndPort = strings.ToLower(hostInfo.GetHostAndPort()) + } + return (b.hostAndPort == "" || b.hostAndPort == hostAndPort) && (b.role.IsZero() || b.role == hostRole) +} + +func (b BaseRouting) String() string { + hostAndPort := "" + if b.hostAndPort != "" { + hostAndPort = b.hostAndPort + } + + role := "" + if !b.role.IsZero() { + role = b.role.String() + } + + return fmt.Sprintf("%s [%s, %s]", + "Routing", + hostAndPort, + role, + ) +} diff --git a/awssql/plugins/bg/bg_info_helpers.go b/awssql/plugins/bg/bg_info_helpers.go new file mode 100644 index 00000000..75480610 --- /dev/null +++ b/awssql/plugins/bg/bg_info_helpers.go @@ -0,0 +1,40 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" +) + +type StatusInfo struct { + version string + endpoint string + port int + phase driver_infrastructure.BlueGreenPhase + role driver_infrastructure.BlueGreenRole +} + +func (s *StatusInfo) IsZero() bool { + return s == nil || (s.version == "" && s.endpoint == "" && s.port == 0) +} + +type PhaseTimeInfo struct { + Timestamp time.Time + Phase driver_infrastructure.BlueGreenPhase +} diff --git a/awssql/plugins/bg/bg_interim_status.go b/awssql/plugins/bg/bg_interim_status.go new file mode 100644 index 00000000..edf17941 --- /dev/null +++ b/awssql/plugins/bg/bg_interim_status.go @@ -0,0 +1,183 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "fmt" + "hash/fnv" + "log/slog" + "sort" + "strconv" + "strings" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +type BlueGreenInterimStatus struct { + phase driver_infrastructure.BlueGreenPhase + version string + port int + startTopology []*host_info_util.HostInfo + endTopology []*host_info_util.HostInfo + startIpAddressesByHostMap map[string]string + currentIpAddressesByHostMap map[string]string + hostNames map[string]bool + allStartTopologyIpChanged bool + allStartTopologyEndpointsRemoved bool + allTopologyChanged bool +} + +func (b *BlueGreenInterimStatus) IsZero() bool { + return b == nil || (b.version == "" && b.port == 0 && b.phase.IsZero()) +} + +func (b *BlueGreenInterimStatus) String() string { + var currentIpMapParts []string + for key, value := range b.currentIpAddressesByHostMap { + currentIpMapParts = append(currentIpMapParts, fmt.Sprintf("%s -> %s", key, value)) + } + currentIpMap := strings.Join(currentIpMapParts, "\n ") + + var startIpMapParts []string + for key, value := range b.startIpAddressesByHostMap { + startIpMapParts = append(startIpMapParts, fmt.Sprintf("%s -> %s", key, value)) + } + startIpMap := strings.Join(startIpMapParts, "\n ") + + allHostNamesStr := strings.Join(utils.AllKeys(b.hostNames), "\n ") + + startTopologyStr := utils.LogTopology(b.startTopology, "") + endTopologyStr := utils.LogTopology(b.endTopology, "") + + phaseStr := "" + if b.phase.GetName() != "" { + phaseStr = b.phase.GetName() + } + + emptyOrValue := func(s string) string { + if strings.TrimSpace(s) == "" { + return "-" + } + return s + } + + return fmt.Sprintf("BlueGreenInterimStatus [\n"+ + " phase %s, \n"+ + " version '%s', \n"+ + " port %d, \n"+ + " hostNames:\n"+ + " %s \n"+ + " Start %s \n"+ + " start IP map:\n"+ + " %s \n"+ + " Current %s \n"+ + " current IP map:\n"+ + " %s \n"+ + " allStartTopologyIpChanged: %t \n"+ + " allStartTopologyEndpointsRemoved: %t \n"+ + " allTopologyChanged: %t \n"+ + "]", + phaseStr, + b.version, + b.port, + emptyOrValue(allHostNamesStr), + emptyOrValue(startTopologyStr), + emptyOrValue(startIpMap), + emptyOrValue(endTopologyStr), + emptyOrValue(currentIpMap), + b.allStartTopologyIpChanged, + b.allStartTopologyEndpointsRemoved, + b.allTopologyChanged) +} + +func (b *BlueGreenInterimStatus) GetCustomHashCode() uint64 { + result := getValueHash(1, b.phase.GetName()) + result = getValueHash(result, b.version) + result = getValueHash(result, strconv.Itoa(b.port)) + result = getValueHash(result, strconv.FormatBool(b.allStartTopologyIpChanged)) + result = getValueHash(result, strconv.FormatBool(b.allStartTopologyEndpointsRemoved)) + result = getValueHash(result, strconv.FormatBool(b.allTopologyChanged)) + + result = getValueHash(result, b.getHostNamesString()) + result = getValueHash(result, b.getTopologyString(b.startTopology)) + result = getValueHash(result, b.getTopologyString(b.endTopology)) + result = getValueHash(result, b.getIpMapString(b.startIpAddressesByHostMap)) + result = getValueHash(result, b.getIpMapString(b.currentIpAddressesByHostMap)) + + return result +} + +func (b *BlueGreenInterimStatus) getHostNamesString() string { + if len(b.hostNames) == 0 { + return "" + } + + // Extract keys from map and sort them + keys := make([]string, 0, len(b.hostNames)) + for key := range b.hostNames { + keys = append(keys, key) + } + sort.Strings(keys) + + return strings.Join(keys, ",") +} + +func (b *BlueGreenInterimStatus) getTopologyString(topology []*host_info_util.HostInfo) string { + if len(topology) == 0 { + return "" + } + + // Map each HostInfo to string and sort + hostStrings := make([]string, len(topology)) + for i, hostInfo := range topology { + hostStrings[i] = hostInfo.GetHostAndPort() + string(hostInfo.Role) + } + sort.Strings(hostStrings) + + return strings.Join(hostStrings, ",") +} + +func (b *BlueGreenInterimStatus) getIpMapString(ipMap map[string]string) string { + if len(ipMap) == 0 { + return "" + } + + // Convert map entries to strings and sort + entries := make([]string, 0, len(ipMap)) + for key, value := range ipMap { + entries = append(entries, key+value) + } + sort.Strings(entries) + + return strings.Join(entries, ",") +} + +func getValueHash(currentHash uint64, val string) uint64 { + // Use FNV-1a hash algorithm for string hashing + h := fnv.New64a() + _, err := h.Write([]byte(val)) + if err != nil { + slog.Warn(error_util.GetMessage("BlueGreenDeployment.errorGeneratingHash", err)) + return 0 + } + stringHash := h.Sum64() + + return currentHash*31 + stringHash +} diff --git a/awssql/plugins/bg/bg_plugin.go b/awssql/plugins/bg/bg_plugin.go new file mode 100644 index 00000000..d5db77b9 --- /dev/null +++ b/awssql/plugins/bg/bg_plugin.go @@ -0,0 +1,255 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "database/sql/driver" + "slices" + "sync/atomic" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugin_helpers" + "github.com/aws/aws-advanced-go-wrapper/awssql/plugins" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +var bgSubscribedMethods = append(utils.NETWORK_BOUND_METHODS, plugin_helpers.CONNECT_METHOD) +var providers = utils.NewRWMapWithDisposalFunc(func(provider *BlueGreenStatusProvider) bool { + if provider != nil { + provider.ClearMonitors() + } + return true +}) + +type BlueGreenPluginFactory struct{} + +func (b *BlueGreenPluginFactory) ClearCaches() { + providers.Clear() +} + +func (b *BlueGreenPluginFactory) GetInstance(pluginService driver_infrastructure.PluginService, props map[string]string) (driver_infrastructure.ConnectionPlugin, error) { + return NewBlueGreenPlugin(pluginService, props) +} + +func NewBlueGreenPluginFactory() driver_infrastructure.ConnectionPluginFactory { + return &BlueGreenPluginFactory{} +} + +type BlueGreenPlugin struct { + bgId string + bgStatus driver_infrastructure.BlueGreenStatus + bgProviderSupplier BlueGreenProviderSupplier + isIamInUse bool + pluginService driver_infrastructure.PluginService + props map[string]string + startTime atomic.Int64 + endTime atomic.Int64 + plugins.BaseConnectionPlugin +} + +func NewBlueGreenPlugin(pluginService driver_infrastructure.PluginService, + props map[string]string) (driver_infrastructure.ConnectionPlugin, error) { + bgId := property_util.GetVerifiedWrapperPropertyValue[string](props, property_util.BGD_ID) + if bgId == "" { + return nil, error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.bgIdRequired")) + } + return &BlueGreenPlugin{ + bgId: bgId, + props: props, + pluginService: pluginService, + bgProviderSupplier: NewBlueGreenStatusProvider, + }, nil +} + +func (b *BlueGreenPlugin) GetPluginCode() string { + return driver_infrastructure.BLUE_GREEN_PLUGIN_CODE +} + +func (b *BlueGreenPlugin) GetSubscribedMethods() []string { + return bgSubscribedMethods +} + +func (b *BlueGreenPlugin) Connect( + hostInfo *host_info_util.HostInfo, + props map[string]string, + isInitialConnection bool, + connectFunc driver_infrastructure.ConnectFunc) (conn driver.Conn, err error) { + b.resetRoutingTimeNano() + defer func() { + if b.startTime.Load() > 0 { + b.endTime.CompareAndSwap(0, time.Now().Unix()) + } + }() + + bgStatus, ok := b.pluginService.GetBgStatus(b.bgId) + b.bgStatus = bgStatus + + if b.bgStatus.IsZero() || !ok { + // Connection does not require BG logic. + return b.regularOpenConnection(connectFunc, isInitialConnection) + } + + if isInitialConnection { + // Upon initial connection, mark whether iam is in use. + b.isIamInUse = b.pluginService.IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE) + } + + hostRole, ok := b.bgStatus.GetRole(hostInfo) + + if !ok || hostRole.IsZero() { + // Connection to a host that is not participating in BG switchover. + return b.regularOpenConnection(connectFunc, isInitialConnection) + } + + matchingRoutes := utils.FilterSlice(b.bgStatus.GetConnectRoutings(), func(r driver_infrastructure.ConnectRouting) bool { + return r.IsMatch(hostInfo, hostRole) + }) + + if len(matchingRoutes) == 0 { + return b.regularOpenConnection(connectFunc, isInitialConnection) + } + + b.startTime.Store(time.Now().UnixNano()) + routing := matchingRoutes[0] + for routing != nil && conn == nil { + conn, err = routing.Apply(b, hostInfo, props, isInitialConnection, b.pluginService) + if conn == nil { + b.bgStatus, ok = b.pluginService.GetBgStatus(b.bgId) + if !b.bgStatus.IsZero() && ok { + matchingRoutes := utils.FilterSlice(b.bgStatus.GetConnectRoutings(), func(r driver_infrastructure.ConnectRouting) bool { + return r.IsMatch(hostInfo, hostRole) + }) + + if len(matchingRoutes) != 0 { + routing = matchingRoutes[0] + continue + } + } + routing = nil + } + } + + if conn == nil { + conn, err = connectFunc(props) + } + + if isInitialConnection { + // Provider should be initialized after connection is open and a dialect is properly identified. + b.initProvider() + } + return conn, err +} + +func (b *BlueGreenPlugin) Execute( + _ driver.Conn, + methodName string, + executeFunc driver_infrastructure.ExecuteFunc, + methodArgs ...any) (wrappedReturnValue any, wrappedReturnValue2 any, wrappedOk bool, wrappedErr error) { + b.resetRoutingTimeNano() + defer func() { + if b.startTime.Load() > 0 { + b.endTime.CompareAndSwap(0, time.Now().Unix()) + } + }() + b.initProvider() + if slices.Contains(utils.CLOSING_METHODS, methodName) { + return executeFunc() + } + bgStatus, ok := b.pluginService.GetBgStatus(b.bgId) + b.bgStatus = bgStatus + if b.bgStatus.IsZero() || !ok { + return executeFunc() + } + currentHostInfo, err := b.pluginService.GetCurrentHostInfo() + hostRole, ok := b.bgStatus.GetRole(currentHostInfo) + if err != nil || !ok || hostRole.IsZero() { + return executeFunc() + } + + matchingRoutes := utils.FilterSlice(b.bgStatus.GetExecuteRoutings(), func(r driver_infrastructure.ExecuteRouting) bool { + return r.IsMatch(currentHostInfo, hostRole) + }) + + if len(matchingRoutes) == 0 { + return executeFunc() + } + b.startTime.Store(time.Now().UnixNano()) + routing := matchingRoutes[0] + result := driver_infrastructure.EMPTY_ROUTING_RESULT_HOLDER + for routing != nil && !result.IsPresent() { + result = routing.Apply(b, b.props, b.pluginService, methodName, executeFunc, methodArgs...) + if !result.IsPresent() { + b.bgStatus, ok = b.pluginService.GetBgStatus(b.bgId) + if b.bgStatus.IsZero() || !ok { + b.endTime.Store(time.Now().UnixNano()) + return executeFunc() + } + matchingRoutes := utils.FilterSlice(b.bgStatus.GetExecuteRoutings(), func(r driver_infrastructure.ExecuteRouting) bool { + return r.IsMatch(currentHostInfo, hostRole) + }) + + if len(matchingRoutes) != 0 { + routing = matchingRoutes[0] + continue + } + routing = nil + } + } + + b.endTime.Store(time.Now().UnixNano()) + if result.IsPresent() { + return result.GetResult() + } + return executeFunc() +} + +func (b *BlueGreenPlugin) initProvider() { + provider, ok := providers.Get(b.bgId) + if !ok || provider.isZero() { + provider = b.bgProviderSupplier(b.pluginService, b.props, b.bgId) + providers.Put(b.bgId, provider) + } +} + +func (b *BlueGreenPlugin) regularOpenConnection(connectFunc driver_infrastructure.ConnectFunc, isInitialConnection bool) (driver.Conn, error) { + conn, err := connectFunc(b.props) + if isInitialConnection { + // Provider should be initialized after connection is open and a dialect is properly identified. + b.initProvider() + } + return conn, err +} + +// For testing purposes only. +func (b *BlueGreenPlugin) GetHoldTimeNano() time.Duration { + if b.startTime.Load() == 0 { + return 0 * time.Nanosecond + } + if b.endTime.Load() == 0 { + return time.Duration(time.Now().Unix()-b.startTime.Load()) * time.Nanosecond + } + return time.Duration(b.endTime.Load()-b.startTime.Load()) * time.Nanosecond +} + +func (b *BlueGreenPlugin) resetRoutingTimeNano() { + b.startTime.Store(0) + b.endTime.Store(0) +} diff --git a/awssql/plugins/bg/bg_status_monitor.go b/awssql/plugins/bg/bg_status_monitor.go new file mode 100644 index 00000000..69b26a81 --- /dev/null +++ b/awssql/plugins/bg/bg_status_monitor.go @@ -0,0 +1,488 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "database/sql/driver" + "fmt" + "log/slog" + "net" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +const BG_CLUSTER_ID = "941d00a8-8238-4f7d-bf59-771bff783a8e" +const DEFAULT_CHECK_INTERVAL = 5 * 60_000 // 5 minutes +const latestKnownVersion = "1.0" + +var knownVersions = []string{latestKnownVersion} + +type BlueGreenStatusMonitor struct { + role driver_infrastructure.BlueGreenRole + bgId string + initialHostInfo *host_info_util.HostInfo + hostListProvider driver_infrastructure.HostListProvider + pluginService driver_infrastructure.PluginService + blueGreenDialect driver_infrastructure.BlueGreenDialect + props map[string]string + statusCheckIntervalMap map[driver_infrastructure.BlueGreenIntervalRate]int + onBlueGreenStatusChangeFunc func(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus) + currentPhase driver_infrastructure.BlueGreenPhase + version string + port int + allStartTopologyIpChanged bool + allStartTopologyEndpointsRemoved bool + allTopologyChanged bool + startIpAddressesByHostMap *utils.RWMap[string] + currentIpAddressesByHostMap *utils.RWMap[string] + hostNames *utils.RWMap[bool] + startTopology []*host_info_util.HostInfo + currentTopology atomic.Pointer[[]*host_info_util.HostInfo] + connection atomic.Pointer[driver.Conn] + connectionHostInfo atomic.Pointer[host_info_util.HostInfo] + connectedIpAddress atomic.Value + intervalRate atomic.Int32 + collectedIpAddresses atomic.Bool + collectedTopology atomic.Bool + connectionHostInfoCorrect atomic.Bool + useIpAddress atomic.Bool + panicMode atomic.Bool + stop atomic.Bool + wg sync.WaitGroup +} + +func NewBlueGreenStatusMonitor(blueGreenRole driver_infrastructure.BlueGreenRole, bgdId string, hostInfo *host_info_util.HostInfo, + pluginService driver_infrastructure.PluginService, monitoringProps map[string]string, statusCheckIntervalMap map[driver_infrastructure.BlueGreenIntervalRate]int, + onBlueGreenStatusChangeFunc func(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus)) *BlueGreenStatusMonitor { + dialect, _ := pluginService.GetDialect().(driver_infrastructure.BlueGreenDialect) + monitor := BlueGreenStatusMonitor{ + role: blueGreenRole, + bgId: bgdId, + initialHostInfo: hostInfo, + pluginService: pluginService, + props: monitoringProps, + statusCheckIntervalMap: statusCheckIntervalMap, + onBlueGreenStatusChangeFunc: onBlueGreenStatusChangeFunc, + blueGreenDialect: dialect, + currentPhase: driver_infrastructure.NOT_CREATED, + version: "1.0", + port: -1, + startIpAddressesByHostMap: utils.NewRWMap[string](), + currentIpAddressesByHostMap: utils.NewRWMap[string](), + hostNames: utils.NewRWMap[bool](), + startTopology: []*host_info_util.HostInfo{}, + } + monitor.stop.Store(false) + monitor.panicMode.Store(true) + monitor.intervalRate.Store(int32(driver_infrastructure.BASELINE)) + + go monitor.runMonitoringLoop() + return &monitor +} + +func (b *BlueGreenStatusMonitor) runMonitoringLoop() { + b.wg.Add(1) + defer func() { + b.CloseConnection() + slog.Debug(error_util.GetMessage("BlueGreenDeployment.monitoringLoopCompleted", b.role)) + b.wg.Done() + }() + + for !b.stop.Load() { + var oldPhase = b.currentPhase + + b.OpenConnection() + b.CollectStatus() + _ = b.CollectTopology() + b.CollectHostIpAddresses() + b.UpdateIpAddressFlags() + + if !b.currentPhase.IsZero() && !b.currentPhase.Equals(oldPhase) { + slog.Warn(error_util.GetMessage("BlueGreenDeployment.statusChanged", b.role, b.currentPhase)) + } + + if !b.stop.Load() && b.onBlueGreenStatusChangeFunc != nil { + b.onBlueGreenStatusChangeFunc( + b.role, + BlueGreenInterimStatus{ + b.currentPhase, + b.version, + b.port, + b.startTopology, + b.GetCurrentTopology(), + b.startIpAddressesByHostMap.GetAllEntries(), + b.currentIpAddressesByHostMap.GetAllEntries(), + b.hostNames.GetAllEntries(), + b.allStartTopologyIpChanged, + b.allStartTopologyEndpointsRemoved, + b.allTopologyChanged, + }, + ) + } + + b.Delay() + } +} + +func (b *BlueGreenStatusMonitor) Delay() { + currentPanic := b.panicMode.Load() + currentBgIntervalRate := b.GetIntervalRate() + var delayMs int + var ok bool + if currentPanic { + delayMs, ok = b.statusCheckIntervalMap[driver_infrastructure.HIGH] + } else { + delayMs, ok = b.statusCheckIntervalMap[currentBgIntervalRate] + } + if !ok || delayMs == 0 { + delayMs = DEFAULT_CHECK_INTERVAL + } + + endTime := time.Now().Add(time.Millisecond * time.Duration(delayMs)) + minDelay := min(delayMs, 50) + + for b.GetIntervalRate() == currentBgIntervalRate && time.Now().Before(endTime) && !b.stop.Load() && b.panicMode.Load() == currentPanic { + time.Sleep(time.Duration(minDelay)) + } +} + +func (b *BlueGreenStatusMonitor) SetIntervalRate(intervalRate driver_infrastructure.BlueGreenIntervalRate) { + b.intervalRate.Store(int32(intervalRate)) +} + +func (b *BlueGreenStatusMonitor) GetIntervalRate() driver_infrastructure.BlueGreenIntervalRate { + return mapToBlueGreenIntervalRate(b.intervalRate.Load()) +} + +func mapToBlueGreenIntervalRate(value int32) driver_infrastructure.BlueGreenIntervalRate { + switch value { + case int32(driver_infrastructure.BASELINE): + return driver_infrastructure.BASELINE + case int32(driver_infrastructure.INCREASED): + return driver_infrastructure.INCREASED + case int32(driver_infrastructure.HIGH): + return driver_infrastructure.HIGH + default: + return driver_infrastructure.INVALID + } +} + +func (b *BlueGreenStatusMonitor) UpdateIpAddressFlags() { + if b.collectedTopology.Load() { + b.allStartTopologyIpChanged = false + b.allStartTopologyEndpointsRemoved = false + b.allTopologyChanged = false + return + } + + if !b.collectedIpAddresses.Load() { + haveAllChanged := len(b.startTopology) > 0 + for _, hostInfo := range b.startTopology { + host := hostInfo.Host + startIp, startIpExists := b.startIpAddressesByHostMap.Get(host) + currentIp, currentIpExists := b.currentIpAddressesByHostMap.Get(host) + if !startIpExists || !currentIpExists || startIp == currentIp { + // If any of the hosts in startTopology have not changed - return false. + haveAllChanged = false + break + } + } + b.allStartTopologyIpChanged = haveAllChanged + } + + haveAllChanged := len(b.startTopology) > 0 + for _, hostInfo := range b.startTopology { + host := hostInfo.Host + startIp, startIpExists := b.startIpAddressesByHostMap.Get(host) + currentIp, currentIpExists := b.currentIpAddressesByHostMap.Get(host) + if !startIpExists || startIp == "" || (currentIpExists && currentIp != "") { + // If any of the hosts in startTopology still has an IP address - return false. + haveAllChanged = false + break + } + } + b.allStartTopologyEndpointsRemoved = haveAllChanged + + if !b.collectedTopology.Load() { + currentTopologyCopy := b.GetCurrentTopology() + b.allTopologyChanged = len(currentTopologyCopy) > 0 && len(b.startTopology) > 0 && host_info_util.HaveNoHostsInCommon(currentTopologyCopy, b.startTopology) + } +} + +func (b *BlueGreenStatusMonitor) CollectHostIpAddresses() { + b.currentIpAddressesByHostMap.Clear() + + if b.hostNames.Size() == 0 { + return + } + + for host := range b.hostNames.GetAllEntries() { + b.currentIpAddressesByHostMap.PutIfAbsent(host, b.GetIpAddress(host)) + } + + if b.collectedIpAddresses.Load() { + b.startIpAddressesByHostMap.ReplaceCacheWithCopy(b.currentIpAddressesByHostMap) + } +} + +func (b *BlueGreenStatusMonitor) CollectTopology() error { + if b.hostListProvider == nil { + return nil + } + + conn := b.connection.Load() + if conn == nil || b.pluginService.GetTargetDriverDialect().IsClosed(*conn) { + return nil + } + hosts, err := b.hostListProvider.ForceRefresh(*conn) + if err != nil { + slog.Warn(err.Error()) + return err + } + b.currentTopology.Store(&hosts) + + currentTopologyCopy := b.GetCurrentTopology() + if b.collectedTopology.Load() && currentTopologyCopy != nil { + b.startTopology = currentTopologyCopy + for _, hostInfo := range currentTopologyCopy { + b.hostNames.Put(hostInfo.Host, true) + } + } + return nil +} + +func (b *BlueGreenStatusMonitor) CollectStatus() { + connPtr := b.connection.Load() + if connPtr == nil || b.pluginService.GetTargetDriverDialect().IsClosed(*connPtr) { + return + } + conn := *connPtr + + if !b.blueGreenDialect.IsBlueGreenStatusAvailable(conn) { + if !b.pluginService.GetTargetDriverDialect().IsClosed(conn) { + b.currentPhase = driver_infrastructure.NOT_CREATED + slog.Debug(error_util.GetMessage("BlueGreenDeployment.statusNotAvailable", b.role, driver_infrastructure.NOT_CREATED)) + } else { + b.connection.Store(nil) + b.currentPhase = driver_infrastructure.BlueGreenPhase{} + b.panicMode.Store(true) + } + return + } + + results := b.blueGreenDialect.GetBlueGreenStatus(conn) + + statusEntries := make([]StatusInfo, 0, len(results)) + for _, result := range results { + version := result.Version + if !slices.Contains(knownVersions, version) { + versionCopy := version + version = latestKnownVersion + slog.Warn(error_util.GetMessage("BlueGreenDeployment.unknownVersion", versionCopy)) + } + role := driver_infrastructure.ParseRole(result.Role) + if b.role != role { + continue + } + phase := driver_infrastructure.ParsePhase(result.Status) + statusEntries = append(statusEntries, StatusInfo{version, result.Endpoint, result.Port, phase, role}) + } + + statusInfo := utils.FilterSliceFindFirst(statusEntries, func(s StatusInfo) bool { + return utils.IsWriterClusterDns(s.endpoint) && utils.IsNotOldInstance(s.endpoint) + }) + + if !statusInfo.IsZero() { + // Cluster writer endpoint found. Add reader endpoint as well. + b.hostNames.Put(strings.Replace(strings.ToLower(statusInfo.endpoint), ".cluster-", ".cluster-ro-", 1), true) + } + + if statusInfo.IsZero() { + // Could be an instance endpoint. + statusInfo = utils.FilterSliceFindFirst(statusEntries, func(s StatusInfo) bool { + return utils.IsRdsInstance(s.endpoint) && utils.IsNotOldInstance(s.endpoint) + }) + } + + if statusInfo.IsZero() { + if len(statusEntries) == 0 { + // It's normal to expect that the status table has no entries after BGD is completed. + // Old1 cluster/instance has been separated and no longer receives updates from related green cluster/instance. + if b.role != driver_infrastructure.SOURCE { + slog.Warn(error_util.GetMessage("BlueGreenDeployment.noEntriesInStatusTable", b.role)) + } + b.currentPhase = driver_infrastructure.BlueGreenPhase{} + } + } else { + b.currentPhase = statusInfo.phase + b.version = statusInfo.version + b.port = statusInfo.port + } + + if b.collectedTopology.Load() { + for _, statusInfo := range statusEntries { + if statusInfo.endpoint != "" && utils.IsNotOldInstance(statusInfo.endpoint) { + b.hostNames.Put(strings.ToLower(statusInfo.endpoint), true) + } + } + } + + if !b.connectionHostInfoCorrect.Load() && !statusInfo.IsZero() { + // We connected to an initialHostInfo that might be not the desired Blue or Green cluster. + // We need to reconnect to a correct one. + statusInfoHostIpAddress := b.GetIpAddress(statusInfo.endpoint) + connectedIpAddressCopy := b.GetConnectedIpAddress() + if connectedIpAddressCopy != "" && connectedIpAddressCopy != statusInfoHostIpAddress { + // Found endpoint confirms that we're connected to a different host, and we need to reconnect. + reconnectionHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost(statusInfo.endpoint).SetPort(statusInfo.port).Build() + b.connectionHostInfo.Store(reconnectionHostInfo) + b.connectionHostInfoCorrect.Store(true) + b.CloseConnection() + b.panicMode.Store(true) + } else { + // We're already connected to a correct host. + b.connectionHostInfoCorrect.Store(true) + b.panicMode.Store(false) + } + + if b.connectionHostInfoCorrect.Load() && b.hostListProvider == nil { + // A connection to a correct cluster (blue or green) is established. + // Let's initialize HostListProvider + b.InitHostListProvider() + } + } +} + +func (b *BlueGreenStatusMonitor) InitHostListProvider() { + if b.hostListProvider != nil || !b.connectionHostInfoCorrect.Load() { + return + } + + hostListProps := utils.CreateMapCopy(b.props) + + // Need to instantiate a separate HostListProvider with + // a special unique clusterId to avoid interference with other HostListProviders opened for this cluster. + // Blue and Green clusters are expected to have different clusterId. + clusterId := fmt.Sprintf("%s::%v::%s", b.bgId, b.role, BG_CLUSTER_ID) + hostListProps[property_util.CLUSTER_ID.Name] = clusterId + if hostInfo := b.connectionHostInfo.Load(); !hostInfo.IsNil() && b.connectionHostInfoCorrect.Load() { + hostListProps[property_util.PORT.Name] = strconv.Itoa(hostInfo.Port) + hostListProps[property_util.HOST.Name] = hostInfo.Host + } + slog.Debug(error_util.GetMessage("BlueGreenDeployment.createHostListProvider", b.role, clusterId)) + b.hostListProvider = b.pluginService.CreateHostListProvider(hostListProps) +} + +func (b *BlueGreenStatusMonitor) OpenConnection() { + conn := b.connection.Load() + if conn != nil && !b.pluginService.GetTargetDriverDialect().IsClosed(*conn) { + return + } + b.connection.Store(nil) + b.panicMode.Store(true) + + connectionHostInfoCopy := b.connectionHostInfo.Load() + connectedIpAddressCopy := b.GetConnectedIpAddress() + + if connectionHostInfoCopy == nil { + b.connectionHostInfo.Store(b.initialHostInfo) + connectionHostInfoCopy = b.initialHostInfo + b.connectedIpAddress.Store("") + connectedIpAddressCopy = "" + b.connectionHostInfoCorrect.Store(false) + } + + if b.useIpAddress.Load() && connectedIpAddressCopy != "" { + if connectionWithIpHostInfo, err := host_info_util.NewHostInfoBuilder().CopyFrom(connectionHostInfoCopy).SetHost(connectedIpAddressCopy).Build(); err == nil { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.openingConnectionWithIp", b.role, connectionHostInfoCopy.Host)) + connectWithIpProps := utils.CreateMapCopy(b.props) + connectWithIpProps[property_util.IAM_HOST.Name] = connectionHostInfoCopy.Host + if conn, err := b.pluginService.ForceConnect(connectionWithIpHostInfo, connectWithIpProps); err == nil { + b.connection.Store(&conn) + slog.Debug(error_util.GetMessage("BlueGreenDeployment.openedConnectionWithIp", b.role, connectionHostInfoCopy.Host)) + b.panicMode.Store(false) + return + } + } + } else { + if finalConnectionHostInfoCopy, err := host_info_util.NewHostInfoBuilder().CopyFrom(connectionHostInfoCopy).Build(); err == nil { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.openingConnection", b.role, finalConnectionHostInfoCopy.Host)) + connectedIpAddressCopy = b.GetIpAddress(connectionHostInfoCopy.Host) + if conn, err := b.pluginService.ForceConnect(finalConnectionHostInfoCopy, b.props); err == nil { + b.connection.Store(&conn) + slog.Debug(error_util.GetMessage("BlueGreenDeployment.openedConnection", b.role, finalConnectionHostInfoCopy.Host)) + b.connectedIpAddress.Store(connectedIpAddressCopy) + b.panicMode.Store(false) + return + } + } + } + // Can't open connection. + b.connection.Store(nil) + b.panicMode.Store(true) +} + +func (b *BlueGreenStatusMonitor) CloseConnection() { + conn := b.connection.Load() + b.connection.Store(nil) + + if conn != nil && !b.pluginService.GetTargetDriverDialect().IsClosed(*conn) { + _ = (*conn).Close() + } +} + +func (b *BlueGreenStatusMonitor) GetIpAddress(host string) string { + ips, err := net.LookupIP(host) + if err != nil || len(ips) == 0 { + return "" + } + + return ips[0].String() +} + +func (b *BlueGreenStatusMonitor) GetCurrentTopology() []*host_info_util.HostInfo { + topology := b.currentTopology.Load() + if topology == nil { + return []*host_info_util.HostInfo{} + } + return *topology +} + +func (b *BlueGreenStatusMonitor) GetConnectedIpAddress() (connectedIpAddressCopy string) { + if val := b.connectedIpAddress.Load(); val != nil { + if ip, ok := val.(string); ok { + connectedIpAddressCopy = ip + } + } + return +} + +func (b *BlueGreenStatusMonitor) ResetCollectedData() { + b.startIpAddressesByHostMap.Clear() + b.startTopology = make([]*host_info_util.HostInfo, 0) + b.hostNames.Clear() +} diff --git a/awssql/plugins/bg/bg_status_provider.go b/awssql/plugins/bg/bg_status_provider.go new file mode 100644 index 00000000..a0118d24 --- /dev/null +++ b/awssql/plugins/bg/bg_status_provider.go @@ -0,0 +1,977 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "context" + "fmt" + "log/slog" + "reflect" + "slices" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +type BlueGreenProviderSupplier = func(pluginService driver_infrastructure.PluginService, props map[string]string, bgdId string) *BlueGreenStatusProvider + +type BlueGreenStatusProvider struct { + pluginService driver_infrastructure.PluginService + props map[string]string + bgdId string + statusCheckIntervalMap map[driver_infrastructure.BlueGreenIntervalRate]int + switchoverDuration time.Duration + postStatusEndTime time.Time + latestStatusPhase driver_infrastructure.BlueGreenPhase + summaryStatus driver_infrastructure.BlueGreenStatus + suspendNewBlueConnectionsWhenInProgress bool + rollback bool + blueDnsUpdateCompleted bool + greenDnsRemoved bool + greenTopologyChanged bool + monitors []*BlueGreenStatusMonitor + correspondingHosts *utils.RWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]] + roleByHost *utils.RWMap[driver_infrastructure.BlueGreenRole] + phaseTimeNano *utils.RWMap[PhaseTimeInfo] + iamHostSuccessfulConnects *utils.RWMap[[]string] + hostIpAddresses *utils.RWMap[string] + greenHostChangeNameTimes *utils.RWMap[time.Time] + interimStatuses []BlueGreenInterimStatus + interimStatusHashes []uint64 + lastContextHash uint64 + allGreenHostsChangedName atomic.Bool + processStatusLock sync.Mutex +} + +func NewBlueGreenStatusProvider(pluginService driver_infrastructure.PluginService, props map[string]string, bgId string) *BlueGreenStatusProvider { + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_BASELINE_MS), + driver_infrastructure.INCREASED: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_INCREASED_MS), + driver_infrastructure.HIGH: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_HIGH_MS), + } + + switchoverTimeMs := property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_SWITCHOVER_TIMEOUT_MS) + + provider := &BlueGreenStatusProvider{ + pluginService: pluginService, + props: props, + bgdId: bgId, + statusCheckIntervalMap: statusCheckIntervalMap, + switchoverDuration: time.Millisecond * time.Duration(switchoverTimeMs), + latestStatusPhase: driver_infrastructure.NOT_CREATED, + monitors: []*BlueGreenStatusMonitor{nil, nil}, + correspondingHosts: utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + iamHostSuccessfulConnects: utils.NewRWMap[[]string](), + hostIpAddresses: utils.NewRWMap[string](), + greenHostChangeNameTimes: utils.NewRWMap[time.Time](), + roleByHost: utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + phaseTimeNano: utils.NewRWMap[PhaseTimeInfo](), + interimStatuses: []BlueGreenInterimStatus{{}, {}}, + interimStatusHashes: []uint64{0, 0}, + lastContextHash: 0, + suspendNewBlueConnectionsWhenInProgress: property_util.GetVerifiedWrapperPropertyValue[bool](props, property_util.BG_SUSPEND_NEW_BLUE_CONNECTIONS), + } + + provider.allGreenHostsChangedName.Store(false) + dialect := pluginService.GetDialect() + if _, ok := dialect.(driver_infrastructure.BlueGreenDialect); ok { + provider.initMonitoring() + } else { + slog.Warn(error_util.GetMessage("BlueGreenDeployment.unsupportedDialect", bgId, reflect.TypeOf(dialect))) + } + return provider +} + +func (b *BlueGreenStatusProvider) ClearMonitors() { + for _, monitor := range b.monitors { + if monitor != nil { + monitor.stop.Store(true) + monitor.wg.Wait() + } + } +} + +func (b *BlueGreenStatusProvider) initMonitoring() { + currentHostInfo, _ := b.pluginService.GetCurrentHostInfo() + monitoringProps := b.GetMonitoringProperties() + b.monitors[driver_infrastructure.SOURCE.GetValue()] = NewBlueGreenStatusMonitor( + driver_infrastructure.SOURCE, + b.bgdId, + currentHostInfo, + b.pluginService, + monitoringProps, + b.statusCheckIntervalMap, + b.PrepareStatus) + + b.monitors[driver_infrastructure.TARGET.GetValue()] = NewBlueGreenStatusMonitor( + driver_infrastructure.TARGET, + b.bgdId, + currentHostInfo, + b.pluginService, + monitoringProps, + b.statusCheckIntervalMap, + b.PrepareStatus) +} + +func (b *BlueGreenStatusProvider) GetMonitoringProperties() map[string]string { + monitoringConnectionProps := utils.CreateMapCopy(b.props) + for propKey, propValue := range b.props { + if strings.HasPrefix(propKey, property_util.BG_PROPERTY_PREFIX) { + monitoringConnectionProps[strings.TrimPrefix(propKey, property_util.BG_PROPERTY_PREFIX)] = propValue + delete(monitoringConnectionProps, propKey) + } + } + return monitoringConnectionProps +} + +func (b *BlueGreenStatusProvider) PrepareStatus(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus) { + b.processStatusLock.Lock() + defer b.processStatusLock.Unlock() + + statusHash := interimStatus.GetCustomHashCode() + contextHash := b.getContextHash() + if b.interimStatusHashes[role.GetValue()] == statusHash && b.lastContextHash == contextHash { + // No changes. + return + } + + // Some changes detected. Update summary status. + slog.Debug(error_util.GetMessage("BlueGreenDeployment.interimStatus", b.bgdId, role, interimStatus)) + b.UpdatePhase(role, interimStatus) + + // Store interim status and corresponding hash. + b.interimStatuses[role.GetValue()] = interimStatus + b.interimStatusHashes[role.GetValue()] = statusHash + b.lastContextHash = contextHash + + // Update map of IP addresses. + for key, val := range interimStatus.startIpAddressesByHostMap { + b.hostIpAddresses.Put(key, val) + } + + // Update roleByHost based on provided host names. + for hostName := range interimStatus.hostNames { + b.roleByHost.Put(hostName, role) + } + + b.updateCorrespondingHosts() + err := b.UpdateSummaryStatus(role, interimStatus) + if err != nil { + slog.Warn(err.Error()) + return + } + + err = b.UpdateMonitors() + if err != nil { + slog.Warn(err.Error()) + return + } + b.updateStatusCache() + b.LogCurrentContext() + + // Log final switchover results. + b.LogSwitchoverFinalSummary() + + b.ResetContextWhenCompleted() +} + +func (b *BlueGreenStatusProvider) UpdatePhase(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus) { + roleStatus := b.interimStatuses[role.GetValue()] + latestInterimPhase := driver_infrastructure.NOT_CREATED + if !roleStatus.IsZero() { + latestInterimPhase = roleStatus.phase + } + + if interimStatus.IsZero() || interimStatus.phase.IsZero() { + return + } + + if !latestInterimPhase.IsZero() && interimStatus.phase.GetPhase() < latestInterimPhase.GetPhase() { + b.rollback = true + slog.Debug(error_util.GetMessage("BlueGreenDeployment.rollback", b.bgdId)) + } + + // Do not allow status to move backward unless it is a rollback. + shouldUpdate := false + if b.rollback { + shouldUpdate = interimStatus.phase.GetPhase() < b.latestStatusPhase.GetPhase() + } else { + shouldUpdate = interimStatus.phase.GetPhase() >= b.latestStatusPhase.GetPhase() + } + if shouldUpdate { + b.latestStatusPhase = interimStatus.phase + } +} + +func (b *BlueGreenStatusProvider) updateStatusCache() { + b.pluginService.SetBgStatus(b.summaryStatus, b.bgdId) + b.StorePhaseTime(b.summaryStatus.GetCurrentPhase()) +} + +func (b *BlueGreenStatusProvider) updateCorrespondingHosts() { + b.correspondingHosts.Clear() + + sourceInterimStatus := b.interimStatuses[driver_infrastructure.SOURCE.GetValue()] + targetInterimStatus := b.interimStatuses[driver_infrastructure.TARGET.GetValue()] + + if len(sourceInterimStatus.startTopology) > 0 && len(targetInterimStatus.startTopology) > 0 { + blueWriterHostInfo := b.GetWriterHost(driver_infrastructure.SOURCE) + greenWriterHostInfo := b.GetWriterHost(driver_infrastructure.TARGET) + sortedBlueReaderHostInfos := b.GetReaderHosts(driver_infrastructure.SOURCE) + sortedGreenReaderHostInfos := b.GetReaderHosts(driver_infrastructure.TARGET) + + if !blueWriterHostInfo.IsNil() { + b.correspondingHosts.Put(blueWriterHostInfo.GetHost(), utils.NewPair(blueWriterHostInfo, greenWriterHostInfo)) + } + + if len(sortedBlueReaderHostInfos) > 0 { + // Map blue readers to green hosts. + if len(sortedGreenReaderHostInfos) > 0 { + for index, blueHostInfo := range sortedBlueReaderHostInfos { + greenIndex := index % len(sortedGreenReaderHostInfos) + b.correspondingHosts.Put(blueHostInfo.Host, utils.NewPair(blueHostInfo, sortedGreenReaderHostInfos[greenIndex])) + } + } else { + // There's no green reader hosts. We have to map all blue reader hosts to the green writer. + for _, blueHostInfo := range sortedBlueReaderHostInfos { + b.correspondingHosts.Put(blueHostInfo.Host, utils.NewPair(blueHostInfo, greenWriterHostInfo)) + } + } + } + } + + if len(sourceInterimStatus.hostNames) > 0 && len(targetInterimStatus.hostNames) > 0 { + blueHosts := sourceInterimStatus.hostNames // Why is the writer not added previously? Missing a line? + greenHosts := targetInterimStatus.hostNames + + blueClusterHost := utils.FilterSetFindFirst(blueHosts, func(s string) bool { + return utils.IsWriterClusterDns(s) + }) + greenClusterHost := utils.FilterSetFindFirst(greenHosts, func(s string) bool { + return utils.IsWriterClusterDns(s) + }) + + if blueClusterHost != "" && greenClusterHost != "" { + blueHost, _ := host_info_util.NewHostInfoBuilder().SetHost(blueClusterHost).Build() + greenHost, _ := host_info_util.NewHostInfoBuilder().SetHost(greenClusterHost).Build() + b.correspondingHosts.PutIfAbsent(blueClusterHost, utils.NewPair(blueHost, greenHost)) + } + + blueClusterReaderHost := utils.FilterSetFindFirst(blueHosts, func(s string) bool { + return utils.IsReaderClusterDns(s) + }) + greenClusterReaderHost := utils.FilterSetFindFirst(greenHosts, func(s string) bool { + return utils.IsReaderClusterDns(s) + }) + + if blueClusterReaderHost != "" && greenClusterReaderHost != "" { + blueHost, _ := host_info_util.NewHostInfoBuilder().SetHost(blueClusterReaderHost).Build() + greenHost, _ := host_info_util.NewHostInfoBuilder().SetHost(greenClusterReaderHost).Build() + b.correspondingHosts.PutIfAbsent(blueClusterReaderHost, utils.NewPair(blueHost, greenHost)) + } + + for blueHost := range blueHosts { + if utils.IsRdsCustomClusterDns(blueHost) { + customClusterName := utils.GetRdsClusterId(blueHost) + if customClusterName != "" { + for greenHost := range greenHosts { + if utils.IsRdsCustomClusterDns(greenHost) && customClusterName == utils.RemoveGreenInstancePrefix(utils.GetRdsClusterId(greenHost)) { + blueHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost(blueHost).Build() + greenHostInfo, _ := host_info_util.NewHostInfoBuilder().SetHost(greenHost).Build() + b.correspondingHosts.PutIfAbsent(blueHost, utils.NewPair(blueHostInfo, greenHostInfo)) + break + } + } + } + } + } + } +} + +func (b *BlueGreenStatusProvider) GetWriterHost(role driver_infrastructure.BlueGreenRole) *host_info_util.HostInfo { + interimStatus := b.interimStatuses[role.GetValue()] + if interimStatus.IsZero() { + return nil + } + return host_info_util.GetWriter(interimStatus.startTopology) +} + +func (b *BlueGreenStatusProvider) GetReaderHosts(role driver_infrastructure.BlueGreenRole) []*host_info_util.HostInfo { + interimStatus := b.interimStatuses[role.GetValue()] + if interimStatus.IsZero() { + return nil + } + return host_info_util.GetReaders(interimStatus.startTopology) +} + +func (b *BlueGreenStatusProvider) UpdateSummaryStatus(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus) error { + switch b.latestStatusPhase { + case driver_infrastructure.NOT_CREATED: + b.summaryStatus = driver_infrastructure.NewBgStatus(b.bgdId, driver_infrastructure.NOT_CREATED, nil, nil, nil, nil) + case driver_infrastructure.CREATED: + b.UpdateDnsFlags(role, interimStatus) + b.summaryStatus = b.GetStatusOfCreated() + case driver_infrastructure.PREPARATION: + b.StartSwitchoverTimer() + b.UpdateDnsFlags(role, interimStatus) + b.summaryStatus = b.GetStatusOfPreparation() + case driver_infrastructure.IN_PROGRESS: + b.UpdateDnsFlags(role, interimStatus) + b.summaryStatus = b.GetStatusOfInProgress() + case driver_infrastructure.POST: + b.UpdateDnsFlags(role, interimStatus) + b.summaryStatus = b.GetStatusOfPost() + case driver_infrastructure.COMPLETED: + b.UpdateDnsFlags(role, interimStatus) + b.summaryStatus = b.GetStatusOfCompleted() + default: + return error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.unknownPhase", b.bgdId, b.latestStatusPhase)) + } + return nil +} + +func (b *BlueGreenStatusProvider) UpdateMonitors() error { + switch b.summaryStatus.GetCurrentPhase() { + case driver_infrastructure.NOT_CREATED: + for _, monitor := range b.monitors { + monitor.SetIntervalRate(driver_infrastructure.BASELINE) + monitor.collectedIpAddresses.Store(false) + monitor.collectedTopology.Store(false) + monitor.useIpAddress.Store(false) + } + case driver_infrastructure.CREATED: + for _, monitor := range b.monitors { + monitor.SetIntervalRate(driver_infrastructure.INCREASED) + monitor.collectedIpAddresses.Store(true) + monitor.collectedTopology.Store(true) + monitor.useIpAddress.Store(false) + if b.rollback { + monitor.ResetCollectedData() + } + } + case driver_infrastructure.PREPARATION: + case driver_infrastructure.IN_PROGRESS: + case driver_infrastructure.POST: + for _, monitor := range b.monitors { + monitor.SetIntervalRate(driver_infrastructure.HIGH) + monitor.collectedIpAddresses.Store(false) + monitor.collectedTopology.Store(false) + monitor.useIpAddress.Store(true) + } + case driver_infrastructure.COMPLETED: + for _, monitor := range b.monitors { + monitor.SetIntervalRate(driver_infrastructure.BASELINE) + monitor.collectedIpAddresses.Store(false) + monitor.collectedTopology.Store(false) + monitor.useIpAddress.Store(false) + monitor.ResetCollectedData() + } + + if !b.rollback { + if sourceMonitor := b.monitors[driver_infrastructure.SOURCE.GetValue()]; sourceMonitor != nil { + sourceMonitor.stop.Store(true) + } + } + default: + return error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.unknownPhase", b.bgdId, b.latestStatusPhase)) + } + return nil +} + +func (b *BlueGreenStatusProvider) UpdateDnsFlags(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus) { + if role == driver_infrastructure.SOURCE && !b.blueDnsUpdateCompleted && interimStatus.allStartTopologyIpChanged { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.blueDnsCompleted", b.bgdId)) + b.blueDnsUpdateCompleted = true + b.StoreBlueDnsUpdateTime() + } + + if role == driver_infrastructure.TARGET && !b.greenDnsRemoved && interimStatus.allStartTopologyEndpointsRemoved { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.greenDnsRemoved", b.bgdId)) + b.greenDnsRemoved = true + b.StoreGreenDnsRemoveTime() + } + + if role == driver_infrastructure.TARGET && !b.greenTopologyChanged && interimStatus.allTopologyChanged { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.greenTopologyChanged", b.bgdId)) + b.greenTopologyChanged = true + b.StoreGreenTopologyChangeTime() + } +} + +func (b *BlueGreenStatusProvider) GetStatusOfCreated() driver_infrastructure.BlueGreenStatus { + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.CREATED, + nil, + nil, + b.roleByHost, + b.correspondingHosts, + ) +} + +func (b *BlueGreenStatusProvider) GetStatusOfPreparation() driver_infrastructure.BlueGreenStatus { + if b.IsSwitchoverTimerExpired() { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.switchoverTimeout")) + if b.rollback { + return b.GetStatusOfCreated() + } + return b.GetStatusOfCompleted() + } + connectRoutings := b.AddSubstituteBlueWithIpAddressConnectRouting() + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.PREPARATION, + connectRoutings, + nil, + b.roleByHost, + b.correspondingHosts, + ) +} + +func (b *BlueGreenStatusProvider) AddSubstituteBlueWithIpAddressConnectRouting() []driver_infrastructure.ConnectRouting { + connectRoutings := make([]driver_infrastructure.ConnectRouting, 0, b.roleByHost.Size()*2) + for host, role := range b.roleByHost.GetAllEntries() { + hostPair, ok := b.correspondingHosts.Get(host) + if !ok || role != driver_infrastructure.SOURCE { + continue + } + + blueHostInfo := hostPair.GetLeft() + blueIp, ok := b.hostIpAddresses.Get(blueHostInfo.Host) + blueIpHostInfo := blueHostInfo + if ok && blueIp != "" { + blueIpHostInfo, _ = host_info_util.NewHostInfoBuilder().CopyFrom(blueHostInfo).SetHost(blueIp).Build() + } + + connectRoutings = append(connectRoutings, NewSubstituteConnectRouting( + host, + role, + blueIpHostInfo, + []*host_info_util.HostInfo{blueHostInfo}, + nil, + )) + + interimStatus := b.interimStatuses[role.GetValue()] + if interimStatus.IsZero() { + continue + } + + connectRoutings = append(connectRoutings, NewSubstituteConnectRouting( + host_info_util.GetHostAndPort(host, interimStatus.port), + role, + blueIpHostInfo, + []*host_info_util.HostInfo{blueHostInfo}, + nil, + )) + } + return connectRoutings +} + +func (b *BlueGreenStatusProvider) GetStatusOfInProgress() driver_infrastructure.BlueGreenStatus { + if b.IsSwitchoverTimerExpired() { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.switchoverTimeout")) + if b.rollback { + return b.GetStatusOfCreated() + } + return b.GetStatusOfCompleted() + } + + var connectRoutings []driver_infrastructure.ConnectRouting + if b.suspendNewBlueConnectionsWhenInProgress { + // All blue and green connect calls should be suspended. + connectRoutings = []driver_infrastructure.ConnectRouting{NewSuspendConnectRouting("", driver_infrastructure.SOURCE, b.bgdId)} + } else { + // If we're not suspending new connections then, at least, we need to use IP addresses. + connectRoutings = b.AddSubstituteBlueWithIpAddressConnectRouting() + } + + connectRoutings = append(connectRoutings, NewSuspendConnectRouting("", driver_infrastructure.TARGET, b.bgdId)) + + // All connect calls with IP address that belongs to blue or green host should be suspended. + uniqueIpAddresses := b.getUniqueIpAddresses() + for ipAddress := range uniqueIpAddresses { + if b.suspendNewBlueConnectionsWhenInProgress { + interimStatus := b.interimStatuses[driver_infrastructure.SOURCE.GetValue()] + if interimStatusHasIpAddress(interimStatus, ipAddress) { + connectRoutings = b.appendSuspendConnectRouting(connectRoutings, ipAddress) + connectRoutings = b.appendSuspendConnectRouting(connectRoutings, host_info_util.GetHostAndPort(ipAddress, interimStatus.port)) + continue + } + } + // Try to confirm that the ipAddress belongs to one of the green hosts + interimStatus := b.interimStatuses[driver_infrastructure.TARGET.GetValue()] + if interimStatusHasIpAddress(interimStatus, ipAddress) { + connectRoutings = b.appendSuspendConnectRouting(connectRoutings, ipAddress) + connectRoutings = b.appendSuspendConnectRouting(connectRoutings, host_info_util.GetHostAndPort(ipAddress, interimStatus.port)) + } + } + + executeRoutings := []driver_infrastructure.ExecuteRouting{ + NewSuspendExecuteRouting("", driver_infrastructure.SOURCE, b.bgdId), + NewSuspendExecuteRouting("", driver_infrastructure.TARGET, b.bgdId), + } + + // All traffic through connections with IP addresses that belong to blue or green hosts should be on hold. + for ipAddress := range uniqueIpAddresses { + if b.suspendNewBlueConnectionsWhenInProgress { + interimStatus := b.interimStatuses[driver_infrastructure.SOURCE.GetValue()] + if interimStatusHasIpAddress(interimStatus, ipAddress) { + executeRoutings = b.appendSuspendExecuteRouting(executeRoutings, ipAddress) + executeRoutings = b.appendSuspendExecuteRouting(executeRoutings, host_info_util.GetHostAndPort(ipAddress, interimStatus.port)) + continue + } + } + // Try to confirm that the ipAddress belongs to one of the green hosts + interimStatus := b.interimStatuses[driver_infrastructure.TARGET.GetValue()] + if interimStatusHasIpAddress(interimStatus, ipAddress) { + executeRoutings = b.appendSuspendExecuteRouting(executeRoutings, ipAddress) + executeRoutings = b.appendSuspendExecuteRouting(executeRoutings, host_info_util.GetHostAndPort(ipAddress, interimStatus.port)) + continue + } + executeRoutings = b.appendSuspendExecuteRouting(executeRoutings, ipAddress) + } + + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.IN_PROGRESS, + connectRoutings, + executeRoutings, + b.roleByHost, + b.correspondingHosts, + ) +} + +func (b *BlueGreenStatusProvider) getUniqueIpAddresses() map[string]bool { + uniqueIpAddresses := make(map[string]bool, b.hostIpAddresses.Size()) + for _, ipAddress := range b.hostIpAddresses.GetAllEntries() { + if _, ok := uniqueIpAddresses[ipAddress]; ipAddress != "" && !ok { + uniqueIpAddresses[ipAddress] = true + } + } + return uniqueIpAddresses +} + +func (b *BlueGreenStatusProvider) appendSuspendConnectRouting(connectRoutings []driver_infrastructure.ConnectRouting, host string) []driver_infrastructure.ConnectRouting { + return append(connectRoutings, NewSuspendConnectRouting(host, driver_infrastructure.BlueGreenRole{}, b.bgdId)) +} + +func (b *BlueGreenStatusProvider) appendSuspendExecuteRouting(executeRoutings []driver_infrastructure.ExecuteRouting, host string) []driver_infrastructure.ExecuteRouting { + return append(executeRoutings, NewSuspendExecuteRouting(host, driver_infrastructure.BlueGreenRole{}, b.bgdId)) +} + +func (b *BlueGreenStatusProvider) GetStatusOfPost() driver_infrastructure.BlueGreenStatus { + if b.IsSwitchoverTimerExpired() { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.switchoverTimeout")) + if b.rollback { + return b.GetStatusOfCreated() + } + return b.GetStatusOfCompleted() + } + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.POST, + b.CreatePostRouting(), + []driver_infrastructure.ExecuteRouting{}, + b.roleByHost, + b.correspondingHosts, + ) +} + +func (b *BlueGreenStatusProvider) CreatePostRouting() (connectRoutings []driver_infrastructure.ConnectRouting) { + if b.blueDnsUpdateCompleted && b.allGreenHostsChangedName.Load() { + return + } + + for host, role := range b.roleByHost.GetAllEntries() { + if !b.blueHostInCorrespondingHosts(host, role) { + continue + } + hostPair, ok := b.correspondingHosts.Get(host) + greenHostInfo := hostPair.GetRight() + if !ok || greenHostInfo.IsNil() { + // A corresponding host is not found. We need to suspend this call. + connectRoutings = append(connectRoutings, NewSuspendUntilCorrespondingHostFoundConnectRouting( + host, + role, + b.bgdId, + )) + interimStatus := b.interimStatuses[role.GetValue()] + if !interimStatus.IsZero() { + connectRoutings = append(connectRoutings, NewSuspendUntilCorrespondingHostFoundConnectRouting( + host_info_util.GetHostAndPort(host, interimStatus.port), + role, + b.bgdId, + )) + } + } else { + greenHost := greenHostInfo.Host + greenIp, ok := b.hostIpAddresses.Get(greenHost) + var greenHostInfoWithIp *host_info_util.HostInfo + if ok && greenIp != "" { + greenHostInfoWithIp, _ = host_info_util.NewHostInfoBuilder().CopyFrom(greenHostInfo).SetHost(greenIp).Build() + } else { + greenHostInfoWithIp = greenHostInfo + } + blueHostInfo := hostPair.GetLeft() + var iamHosts []*host_info_util.HostInfo + if b.IsAlreadySuccessfullyConnected(greenHost, host) && !blueHostInfo.IsNil() { + // Green host has already changed its name, and it's not a new blue host (no prefixes). + iamHosts = []*host_info_util.HostInfo{blueHostInfo} + } else if !blueHostInfo.IsNil() { + // Green host hasn't yet changed its name, so we need to try both possible IAM host options. + iamHosts = []*host_info_util.HostInfo{greenHostInfo, blueHostInfo} + } else { + iamHosts = []*host_info_util.HostInfo{greenHostInfo} + } + var iamSuccessfulConnectNotify func(iamHost string) = nil + if utils.IsRdsInstance(host) { + iamSuccessfulConnectNotify = func(iamHost string) { + b.RegisterIamHost(greenHost, iamHost) + } + } + + connectRoutings = append(connectRoutings, NewSubstituteConnectRouting( + host, + role, + greenHostInfoWithIp, + iamHosts, + iamSuccessfulConnectNotify, + )) + interimStatus := b.interimStatuses[role.GetValue()] + if !interimStatus.IsZero() { + connectRoutings = append(connectRoutings, NewSubstituteConnectRouting( + host_info_util.GetHostAndPort(host, interimStatus.port), + role, + greenHostInfoWithIp, + iamHosts, + iamSuccessfulConnectNotify, + )) + } + } + } + + if !b.greenDnsRemoved { + // New connect calls to green endpoints should be rejected. + connectRoutings = append(connectRoutings, NewRejectConnectRouting("", driver_infrastructure.TARGET)) + } + return connectRoutings +} + +func (b *BlueGreenStatusProvider) GetStatusOfCompleted() driver_infrastructure.BlueGreenStatus { + if b.IsSwitchoverTimerExpired() { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.switchoverTimeout")) + if b.rollback { + return b.GetStatusOfCreated() + } + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.COMPLETED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + b.roleByHost, + b.correspondingHosts, + ) + } + // BGD reports that it's completed but DNS hasn't yet updated completely. + // Pretend that status isn't (yet) completed. + if !b.blueDnsUpdateCompleted || !b.greenDnsRemoved { + return b.GetStatusOfPost() + } + return driver_infrastructure.NewBgStatus( + b.bgdId, + driver_infrastructure.COMPLETED, + []driver_infrastructure.ConnectRouting{}, + []driver_infrastructure.ExecuteRouting{}, + b.roleByHost, + utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + ) +} + +func (b *BlueGreenStatusProvider) RegisterIamHost(connectHost string, iamHost string) { + differentHostNames := connectHost != "" && connectHost != iamHost + if differentHostNames && !b.IsAlreadySuccessfullyConnected(connectHost, iamHost) { + b.greenHostChangeNameTimes.Put(connectHost, time.Now()) + slog.Debug(error_util.GetMessage("BlueGreenDeployment.greenHostChangedName", connectHost, iamHost)) + } + + if successfulConnects, ok := b.iamHostSuccessfulConnects.Get(connectHost); ok { + b.iamHostSuccessfulConnects.Put(connectHost, append(successfulConnects, iamHost)) + } else { + b.iamHostSuccessfulConnects.Put(connectHost, []string{iamHost}) + } + + if differentHostNames { + // Check all IAM hosts have changed their names + allHostChangedNames := true + for key, val := range b.iamHostSuccessfulConnects.GetAllEntries() { + if len(val) > 0 && utils.FilterSliceFindFirst(val, func(s string) bool { + return s != key + }) == "" { + allHostChangedNames = false + break + } + } + if allHostChangedNames && !b.allGreenHostsChangedName.Load() { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.allGreenHostChangedName")) + b.allGreenHostsChangedName.Store(true) + b.StoreGreenHostChangeNameTime() + } + } +} + +func (b *BlueGreenStatusProvider) IsAlreadySuccessfullyConnected(connectHost string, iamHost string) bool { + successfulConnects, ok := b.iamHostSuccessfulConnects.Get(connectHost) + return ok && slices.Contains(successfulConnects, iamHost) +} + +func (b *BlueGreenStatusProvider) getContextHash() uint64 { + return getValueHash(getValueHash(1, strconv.FormatBool(b.allGreenHostsChangedName.Load())), strconv.Itoa(b.iamHostSuccessfulConnects.Size())) +} + +func (b *BlueGreenStatusProvider) StorePhaseTime(phase driver_infrastructure.BlueGreenPhase) { + if phase.IsZero() { + return + } + b.PutIfAbsentPhaseTime(phase.GetName(), phase) +} + +func (b *BlueGreenStatusProvider) PutIfAbsentPhaseTime(key string, phase driver_infrastructure.BlueGreenPhase) { + if b.rollback { + key += " (rollback)" + } + b.phaseTimeNano.PutIfAbsent(key, PhaseTimeInfo{ + time.Now(), + phase, + }) +} + +func (b *BlueGreenStatusProvider) StoreBlueDnsUpdateTime() { + b.PutIfAbsentPhaseTime("Blue DNS updated", driver_infrastructure.BlueGreenPhase{}) +} + +func (b *BlueGreenStatusProvider) StoreGreenDnsRemoveTime() { + b.PutIfAbsentPhaseTime("Green DNS removed", driver_infrastructure.BlueGreenPhase{}) +} + +func (b *BlueGreenStatusProvider) StoreGreenHostChangeNameTime() { + b.PutIfAbsentPhaseTime("Green host certificates changed", driver_infrastructure.BlueGreenPhase{}) +} + +func (b *BlueGreenStatusProvider) StoreGreenTopologyChangeTime() { + b.PutIfAbsentPhaseTime("Green topology changed", driver_infrastructure.BlueGreenPhase{}) +} + +func (b *BlueGreenStatusProvider) StartSwitchoverTimer() { + if b.postStatusEndTime.Equal(time.Time{}) { + b.postStatusEndTime = time.Now().Add(b.switchoverDuration) + } +} + +func (b *BlueGreenStatusProvider) IsSwitchoverTimerExpired() bool { + return !b.postStatusEndTime.Equal(time.Time{}) && b.postStatusEndTime.Before(time.Now()) +} + +func (b *BlueGreenStatusProvider) ResetContextWhenCompleted() { + switchoverCompleted := (!b.rollback && b.summaryStatus.GetCurrentPhase() == driver_infrastructure.COMPLETED) || + (b.rollback && b.summaryStatus.GetCurrentPhase() == driver_infrastructure.CREATED) + + hasActiveSwitchoverPhase := utils.FilterMapFindFirstValue(b.phaseTimeNano.GetAllEntries(), func(p PhaseTimeInfo) bool { + return !p.Phase.IsZero() && p.Phase.IsActiveSwitchoverOrCompleted() + }) != PhaseTimeInfo{} + + if switchoverCompleted && hasActiveSwitchoverPhase { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.resetContext")) + b.rollback, b.greenDnsRemoved, b.greenTopologyChanged = false, false, false + b.allGreenHostsChangedName.Store(false) + b.postStatusEndTime = time.Time{} + b.lastContextHash = 0 + b.interimStatusHashes = []uint64{0, 0} + b.interimStatuses = []BlueGreenInterimStatus{{}, {}} + b.latestStatusPhase = driver_infrastructure.NOT_CREATED + b.summaryStatus = driver_infrastructure.BlueGreenStatus{} // double check nil checks etc + b.phaseTimeNano.Clear() + b.hostIpAddresses.Clear() + b.correspondingHosts.Clear() + b.roleByHost.Clear() + b.iamHostSuccessfulConnects.Clear() + b.greenHostChangeNameTimes.Clear() + + if !b.rollback { + if targetMonitor := b.monitors[driver_infrastructure.TARGET.GetValue()]; targetMonitor != nil { + targetMonitor.stop.Store(true) + } + } + } +} + +func (b *BlueGreenStatusProvider) LogSwitchoverFinalSummary() { + switchoverCompleted := (!b.rollback && b.summaryStatus.GetCurrentPhase() == driver_infrastructure.COMPLETED) || + (b.rollback && b.summaryStatus.GetCurrentPhase() == driver_infrastructure.CREATED) + + hasActiveSwitchoverPhase := false + for _, phaseTime := range b.phaseTimeNano.GetAllEntries() { + if !phaseTime.Phase.IsZero() && phaseTime.Phase.IsActiveSwitchoverOrCompleted() { + hasActiveSwitchoverPhase = true + break + } + } + + if !switchoverCompleted || !hasActiveSwitchoverPhase { + return + } + + var timeZeroPhase driver_infrastructure.BlueGreenPhase + var timeZeroKey string + if b.rollback { + timeZeroPhase = driver_infrastructure.PREPARATION + timeZeroKey = timeZeroPhase.GetName() + " (rollback)" + } else { + timeZeroPhase = driver_infrastructure.IN_PROGRESS + timeZeroKey = timeZeroPhase.GetName() + } + + timeZero, hasTimeZero := b.phaseTimeNano.Get(timeZeroKey) + divider := "----------------------------------------------------------------------------------\n" + + // Create sorted slice of phase time entries + type phaseEntry struct { + key string + phaseTime PhaseTimeInfo + } + var entries []phaseEntry + for key, phaseTime := range b.phaseTimeNano.GetAllEntries() { + entries = append(entries, phaseEntry{key: key, phaseTime: phaseTime}) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].phaseTime.Timestamp.Before(entries[j].phaseTime.Timestamp) + }) + + var logMessage strings.Builder + logMessage.WriteString(fmt.Sprintf("[bgdId: '%s']", b.bgdId)) + logMessage.WriteString("\n") + logMessage.WriteString(divider) + logMessage.WriteString(fmt.Sprintf("%-28s %21s %31s\n", "timestamp", "time offset (ms)", "event")) + logMessage.WriteString(divider) + + for _, entry := range entries { + timestampStr := entry.phaseTime.Timestamp.Format("2006-01-02T15:04:05.000Z") + var offsetStr string + if hasTimeZero { + offsetMs := entry.phaseTime.Timestamp.Sub(timeZero.Timestamp).Milliseconds() + offsetStr = fmt.Sprintf("%d ms", offsetMs) + } + logMessage.WriteString(fmt.Sprintf("%28s %18s %31s\n", timestampStr, offsetStr, entry.key)) + } + logMessage.WriteString(divider) + + slog.Info(logMessage.String()) +} + +func (b *BlueGreenStatusProvider) LogCurrentContext() { + if !slog.Default().Enabled(context.TODO(), slog.LevelDebug) { + // We can skip this log message if debug level is in effect + // and more detailed message is going to be printed few lines below. + var currentPhaseStr string + if b.summaryStatus.IsZero() || b.summaryStatus.GetCurrentPhase().IsZero() { + currentPhaseStr = "" + } else { + currentPhaseStr = b.summaryStatus.GetCurrentPhase().GetName() + } + slog.Info(fmt.Sprintf("[bgdId: '%s'] BG status: %s", b.bgdId, currentPhaseStr)) + } + + var summaryStatusStr string + if b.summaryStatus.IsZero() { + summaryStatusStr = "" + } else { + summaryStatusStr = b.summaryStatus.String() + } + slog.Debug(fmt.Sprintf("[bgdId: '%s'] Summary status:\n%s", b.bgdId, summaryStatusStr)) + + var correspondingHostsBuilder strings.Builder + correspondingHostsBuilder.WriteString("Corresponding hosts:\n") + for key, value := range b.correspondingHosts.GetAllEntries() { + var rightHostStr string + if value.GetRight().IsNil() { + rightHostStr = "" + } else { + rightHostStr = value.GetRight().GetHostAndPort() + } + correspondingHostsBuilder.WriteString(fmt.Sprintf(" %s -> %s\n", key, rightHostStr)) + } + slog.Debug(correspondingHostsBuilder.String()) + + var phaseTimesBuilder strings.Builder + phaseTimesBuilder.WriteString("Phase times:\n") + for key, value := range b.phaseTimeNano.GetAllEntries() { + phaseTimesBuilder.WriteString(fmt.Sprintf(" %s -> %s\n", key, value.Timestamp.Format("2006-01-02T15:04:05.000Z"))) + } + slog.Debug(phaseTimesBuilder.String()) + + var greenHostChangeTimesBuilder strings.Builder + greenHostChangeTimesBuilder.WriteString("Green host certificate change times:\n") + for key, value := range b.greenHostChangeNameTimes.GetAllEntries() { + greenHostChangeTimesBuilder.WriteString(fmt.Sprintf(" %s -> %s\n", key, value.Format("2006-01-02T15:04:05.000Z"))) + } + slog.Debug(greenHostChangeTimesBuilder.String()) + + var latestStatusPhaseName string + if b.latestStatusPhase.IsZero() { + latestStatusPhaseName = "" + } else { + latestStatusPhaseName = b.latestStatusPhase.GetName() + } + + slog.Debug(fmt.Sprintf("\n"+ + " latestStatusPhase: %s\n"+ + " blueDnsUpdateCompleted: %t\n"+ + " greenDnsRemoved: %t\n"+ + " greenHostChangedName: %t\n"+ + " greenTopologyChanged: %t", + latestStatusPhaseName, + b.blueDnsUpdateCompleted, + b.greenDnsRemoved, + b.allGreenHostsChangedName.Load(), + b.greenTopologyChanged)) +} + +func (b *BlueGreenStatusProvider) isZero() bool { + return b == nil || (b.pluginService == nil && b.props == nil && b.bgdId == "") +} + +func (b *BlueGreenStatusProvider) blueHostInCorrespondingHosts(host string, role driver_infrastructure.BlueGreenRole) bool { + return role == driver_infrastructure.SOURCE && utils.FilterSetFindFirst(b.correspondingHosts.GetAllEntries(), func(s string) bool { + return s == host + }) != "" +} + +func interimStatusHasIpAddress(interimStatus BlueGreenInterimStatus, ipAddress string) bool { + firstMatchToIpAddress := utils.FilterMapFindFirstValue(interimStatus.startIpAddressesByHostMap, func(s string) bool { + return s != "" && s == ipAddress + }) + return !interimStatus.IsZero() && firstMatchToIpAddress != "" +} diff --git a/awssql/plugins/bg/bg_test_helpers.go b/awssql/plugins/bg/bg_test_helpers.go new file mode 100644 index 00000000..1aec6b88 --- /dev/null +++ b/awssql/plugins/bg/bg_test_helpers.go @@ -0,0 +1,313 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "database/sql/driver" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" +) + +type TestBlueGreenStatusMonitor struct { + *BlueGreenStatusMonitor +} + +func NewTestBlueGreenStatusMonitor(blueGreenRole driver_infrastructure.BlueGreenRole, bgdId string, hostInfo *host_info_util.HostInfo, + pluginService driver_infrastructure.PluginService, monitoringProps map[string]string, statusCheckIntervalMap map[driver_infrastructure.BlueGreenIntervalRate]int, + onBlueGreenStatusChangeFunc func(role driver_infrastructure.BlueGreenRole, interimStatus BlueGreenInterimStatus)) *TestBlueGreenStatusMonitor { + dialect, _ := pluginService.GetDialect().(driver_infrastructure.BlueGreenDialect) + monitor := BlueGreenStatusMonitor{ + role: blueGreenRole, + bgId: bgdId, + initialHostInfo: hostInfo, + pluginService: pluginService, + props: monitoringProps, + statusCheckIntervalMap: statusCheckIntervalMap, + onBlueGreenStatusChangeFunc: onBlueGreenStatusChangeFunc, + blueGreenDialect: dialect, + currentPhase: driver_infrastructure.NOT_CREATED, + version: "1.0", + port: -1, + startIpAddressesByHostMap: utils.NewRWMap[string](), + currentIpAddressesByHostMap: utils.NewRWMap[string](), + hostNames: utils.NewRWMap[bool](), + startTopology: []*host_info_util.HostInfo{}, + } + monitor.stop.Store(true) + monitor.panicMode.Store(true) + monitor.intervalRate.Store(int32(driver_infrastructure.BASELINE)) + return &TestBlueGreenStatusMonitor{&monitor} +} + +func (t *TestBlueGreenStatusMonitor) GetPanicMode() bool { + return t.panicMode.Load() +} + +func (t *TestBlueGreenStatusMonitor) SetPanicMode(val bool) { + if t.panicMode.Load() != val { + t.panicMode.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetCollectedTopology(val bool) { + if t.collectedTopology.Load() != val { + t.collectedTopology.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetCollectedIpAddresses(val bool) { + if t.collectedIpAddresses.Load() != val { + t.collectedIpAddresses.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetConnectionHostInfoCorrect(val bool) { + if t.connectionHostInfoCorrect.Load() != val { + t.connectionHostInfoCorrect.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetUseIpAddress(val bool) { + if t.useIpAddress.Load() != val { + t.useIpAddress.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetStop(val bool) { + if t.stop.Load() != val { + t.stop.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) GetConnection() *driver.Conn { + return t.connection.Load() +} + +func (t *TestBlueGreenStatusMonitor) SetConnection(val *driver.Conn) { + if t.connection.Load() != val { + t.connection.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) GetHostListProvider() driver_infrastructure.HostListProvider { + return t.hostListProvider +} + +func (t *TestBlueGreenStatusMonitor) SetHostListProvider(val driver_infrastructure.HostListProvider) { + t.hostListProvider = val +} + +func (t *TestBlueGreenStatusMonitor) SetAllStartTopologyIpChanged(val bool) { + t.allStartTopologyIpChanged = val +} + +func (t *TestBlueGreenStatusMonitor) GetAllStartTopologyIpChanged() bool { + return t.allStartTopologyIpChanged +} + +func (t *TestBlueGreenStatusMonitor) SetAllStartTopologyEndpointsRemoved(val bool) { + t.allStartTopologyEndpointsRemoved = val +} + +func (t *TestBlueGreenStatusMonitor) GetAllStartTopologyEndpointsRemoved() bool { + return t.allStartTopologyEndpointsRemoved +} + +func (t *TestBlueGreenStatusMonitor) SetAllTopologyChanged(val bool) { + t.allTopologyChanged = val +} + +func (t *TestBlueGreenStatusMonitor) GetAllTopologyChanged() bool { + return t.allTopologyChanged +} + +func (t *TestBlueGreenStatusMonitor) GetStartIpAddressesByHostMap() *utils.RWMap[string] { + return t.startIpAddressesByHostMap +} + +func (t *TestBlueGreenStatusMonitor) GetCurrentIpAddressesByHostMap() *utils.RWMap[string] { + return t.currentIpAddressesByHostMap +} + +func (t *TestBlueGreenStatusMonitor) GetHostNames() *utils.RWMap[bool] { + return t.hostNames +} + +func (t *TestBlueGreenStatusMonitor) GetCurrentPhase() driver_infrastructure.BlueGreenPhase { + return t.currentPhase +} + +func (t *TestBlueGreenStatusMonitor) SetStartTopology(val []*host_info_util.HostInfo) { + t.startTopology = val +} + +func (t *TestBlueGreenStatusMonitor) SetCurrentTopology(val *[]*host_info_util.HostInfo) { + if t.currentTopology.Load() != val { + t.currentTopology.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetConnectedIpAddress(val string) { + if t.connectedIpAddress.Load() != val { + t.connectedIpAddress.Store(val) + } +} + +func (t *TestBlueGreenStatusMonitor) SetConnectionHostInfo(val *host_info_util.HostInfo) { + if t.connectionHostInfo.Load() != val { + t.connectionHostInfo.Store(val) + } +} + +type TestBlueGreenStatusProvider struct { + *BlueGreenStatusProvider +} + +func NewTestBlueGreenStatusProvider(pluginService driver_infrastructure.PluginService, props map[string]string, bgId string) *TestBlueGreenStatusProvider { + statusCheckIntervalMap := map[driver_infrastructure.BlueGreenIntervalRate]int{ + driver_infrastructure.BASELINE: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_BASELINE_MS), + driver_infrastructure.INCREASED: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_INCREASED_MS), + driver_infrastructure.HIGH: property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_INTERVAL_HIGH_MS), + } + + switchoverTimeMs := property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_SWITCHOVER_TIMEOUT_MS) + + provider := &BlueGreenStatusProvider{ + pluginService: pluginService, + props: props, + bgdId: bgId, + statusCheckIntervalMap: statusCheckIntervalMap, + switchoverDuration: time.Millisecond * time.Duration(switchoverTimeMs), + latestStatusPhase: driver_infrastructure.NOT_CREATED, + monitors: []*BlueGreenStatusMonitor{nil, nil}, + correspondingHosts: utils.NewRWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]](), + iamHostSuccessfulConnects: utils.NewRWMap[[]string](), + hostIpAddresses: utils.NewRWMap[string](), + greenHostChangeNameTimes: utils.NewRWMap[time.Time](), + roleByHost: utils.NewRWMap[driver_infrastructure.BlueGreenRole](), + phaseTimeNano: utils.NewRWMap[PhaseTimeInfo](), + interimStatuses: []BlueGreenInterimStatus{{}, {}}, + interimStatusHashes: []uint64{0, 0}, + lastContextHash: 0, + suspendNewBlueConnectionsWhenInProgress: property_util.GetVerifiedWrapperPropertyValue[bool](props, property_util.BG_SUSPEND_NEW_BLUE_CONNECTIONS), + } + + provider.allGreenHostsChangedName.Store(false) + return &TestBlueGreenStatusProvider{provider} +} + +func (t *TestBlueGreenStatusProvider) GetAllGreenHostsChangedName() bool { + return t.allGreenHostsChangedName.Load() +} + +func (t *TestBlueGreenStatusProvider) SetAllGreenHostsChangedName(val bool) { + if t.allGreenHostsChangedName.Load() != val { + t.allGreenHostsChangedName.Store(val) + } +} + +func (t *TestBlueGreenStatusProvider) GetLatestStatusPhase() driver_infrastructure.BlueGreenPhase { + return t.latestStatusPhase +} + +func (t *TestBlueGreenStatusProvider) GetCorrespondingHosts() *utils.RWMap[utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo]] { + return t.correspondingHosts +} + +func (t *TestBlueGreenStatusProvider) GetRoleByHost() *utils.RWMap[driver_infrastructure.BlueGreenRole] { + return t.roleByHost +} + +func (t *TestBlueGreenStatusProvider) GetPhaseTimeNano() *utils.RWMap[PhaseTimeInfo] { + return t.phaseTimeNano +} + +func (t *TestBlueGreenStatusProvider) GetHostIpAddresses() *utils.RWMap[string] { + return t.hostIpAddresses +} + +func (t *TestBlueGreenStatusProvider) GetInterimStatuses() []BlueGreenInterimStatus { + return t.interimStatuses +} + +func (t *TestBlueGreenStatusProvider) GetRollback() bool { + return t.rollback +} + +func (t *TestBlueGreenStatusProvider) SetRollback(val bool) { + t.rollback = val +} + +func (t *TestBlueGreenStatusProvider) SetGreenTopologyChanged(val bool) { + t.greenTopologyChanged = val +} + +func (t *TestBlueGreenStatusProvider) GetGreenTopologyChanged() bool { + return t.greenTopologyChanged +} + +func (t *TestBlueGreenStatusProvider) SetGreenDnsRemoved(val bool) { + t.greenDnsRemoved = val +} + +func (t *TestBlueGreenStatusProvider) GetGreenDnsRemoved() bool { + return t.greenDnsRemoved +} + +func (t *TestBlueGreenStatusProvider) SetBlueDnsUpdateCompleted(val bool) { + t.blueDnsUpdateCompleted = val +} + +func (t *TestBlueGreenStatusProvider) GetBlueDnsUpdateCompleted() bool { + return t.blueDnsUpdateCompleted +} + +func (t *TestBlueGreenStatusProvider) SetSummaryStatus(val driver_infrastructure.BlueGreenStatus) { + t.summaryStatus = val +} + +func (t *TestBlueGreenStatusProvider) SetPostStatusEndTime(val time.Time) { + t.postStatusEndTime = val +} + +func (t *TestBlueGreenStatusProvider) GetPostStatusEndTime() time.Time { + return t.postStatusEndTime +} + +func NewTestBlueGreenInterimStatus(phase driver_infrastructure.BlueGreenPhase, startTopology []*host_info_util.HostInfo, + startIpAddressesByHostMap map[string]string, ipChanged bool, endpointsRemoved bool, allChanged bool) BlueGreenInterimStatus { + return BlueGreenInterimStatus{ + phase: phase, + version: "1.0", + port: 1234, + startTopology: startTopology, + startIpAddressesByHostMap: startIpAddressesByHostMap, + allStartTopologyIpChanged: ipChanged, + allStartTopologyEndpointsRemoved: endpointsRemoved, + allTopologyChanged: allChanged, + } +} + +func NewTestStatusInfo() StatusInfo { + return StatusInfo{ + version: "1.0", + } +} diff --git a/awssql/plugins/bg/connect_routing.go b/awssql/plugins/bg/connect_routing.go new file mode 100644 index 00000000..0df0b78e --- /dev/null +++ b/awssql/plugins/bg/connect_routing.go @@ -0,0 +1,239 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "database/sql/driver" + "fmt" + "log/slog" + "strconv" + "strings" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/host_info_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" +) + +type RejectConnectRouting struct { + BaseRouting +} + +func (r *RejectConnectRouting) Apply(_ driver_infrastructure.ConnectionPlugin, _ *host_info_util.HostInfo, _ map[string]string, + _ bool, _ driver_infrastructure.PluginService) (driver.Conn, error) { + message := error_util.GetMessage("BlueGreenDeployment.inProgressCantConnect") + slog.Debug(message) + return nil, error_util.NewGenericAwsWrapperError(message) +} + +func NewRejectConnectRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole) *RejectConnectRouting { + return &RejectConnectRouting{NewBaseRouting(hostAndPort, role)} +} + +type IamSuccessfulConnectFunc = func(string) + +type SubstituteConnectRouting struct { + substituteHostInfo *host_info_util.HostInfo + iamHosts []*host_info_util.HostInfo + iamSuccessfulConnectNotify IamSuccessfulConnectFunc + BaseRouting +} + +func (r *SubstituteConnectRouting) Apply(plugin driver_infrastructure.ConnectionPlugin, _ *host_info_util.HostInfo, props map[string]string, + _ bool, pluginService driver_infrastructure.PluginService) (driver.Conn, error) { + if utils.IsIP(r.substituteHostInfo.GetHost()) { + return pluginService.Connect(r.substituteHostInfo, props, plugin) + } + + iamInUse := pluginService.IsPluginInUse(driver_infrastructure.IAM_PLUGIN_CODE) + if !iamInUse { + return pluginService.Connect(r.substituteHostInfo, props, plugin) + } + + if len(r.iamHosts) == 0 { + return nil, error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.requireIamHost")) + } + + for _, iamHost := range r.iamHosts { + if iamHost == nil { + // Skip nil entries. + continue + } + var reroutedHostInfo *host_info_util.HostInfo + if r.substituteHostInfo.GetHost() == "" { + reroutedHostInfo, _ = host_info_util.NewHostInfoBuilder().SetHost(iamHost.Host).SetHostId(iamHost.HostId).SetAvailability(host_info_util.AVAILABLE).Build() + } else { + reroutedHostInfo, _ = host_info_util.NewHostInfoBuilder().CopyFrom(r.substituteHostInfo).SetHostId(iamHost.HostId).SetAvailability(host_info_util.AVAILABLE).Build() + reroutedHostInfo.AddAlias(iamHost.GetHost()) + } + rerouteProps := utils.CreateMapCopy(props) + rerouteProps[property_util.IAM_HOST.Name] = iamHost.GetHost() + if iamHost.IsPortSpecified() { + rerouteProps[property_util.IAM_DEFAULT_PORT.Name] = strconv.Itoa(iamHost.Port) + } + + conn, err := pluginService.Connect(reroutedHostInfo, rerouteProps, nil) + if err == nil { + if r.iamSuccessfulConnectNotify != nil { + r.iamSuccessfulConnectNotify(iamHost.GetHost()) + } + return conn, err + } + } + return nil, error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.inProgressCantOpenConnection", r.substituteHostInfo.GetHostAndPort())) +} + +func (r *SubstituteConnectRouting) String() string { + hostAndPort := "" + if r.hostAndPort != "" { + hostAndPort = r.hostAndPort + } + + role := "" + if !r.role.IsZero() { + role = r.role.String() + } + + substituteHostInfo := "" + if !r.substituteHostInfo.IsNil() { + substituteHostInfo = r.substituteHostInfo.GetHostAndPort() + } + + iamHosts := "" + if len(r.iamHosts) > 0 { + hostPorts := make([]string, len(r.iamHosts)) + for i, host := range r.iamHosts { + hostPorts[i] = host.GetHostAndPort() + } + iamHosts = strings.Join(hostPorts, ", ") + } + + return fmt.Sprintf("%s [%s, %s, substitute: %s, iamHosts: %s]", + "SubstituteConnectRouting", + hostAndPort, + role, + substituteHostInfo, + iamHosts, + ) +} + +func NewSubstituteConnectRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole, substituteHostInfo *host_info_util.HostInfo, + iamHosts []*host_info_util.HostInfo, iamSuccessfulConnectNotify IamSuccessfulConnectFunc) *SubstituteConnectRouting { + return &SubstituteConnectRouting{substituteHostInfo, iamHosts, iamSuccessfulConnectNotify, NewBaseRouting(hostAndPort, role)} +} + +type SuspendConnectRouting struct { + bgId string + BaseRouting +} + +func (r *SuspendConnectRouting) Apply(_ driver_infrastructure.ConnectionPlugin, _ *host_info_util.HostInfo, props map[string]string, + _ bool, pluginService driver_infrastructure.PluginService) (driver.Conn, error) { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.inProgressSuspendConnect")) + parentCtx := pluginService.GetTelemetryContext() + telemetryFactory := pluginService.GetTelemetryFactory() + telemetryCtx, ctx := telemetryFactory.OpenTelemetryContext(TELEMETRY_SWITCHOVER, telemetry.NESTED, parentCtx) + + pluginService.SetTelemetryContext(ctx) + defer func() { + telemetryCtx.CloseContext() + pluginService.SetTelemetryContext(parentCtx) + }() + + bgStatus, ok := pluginService.GetBgStatus(r.bgId) + + timeoutMs := property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_CONNECT_TIMEOUT_MS) + holdStartTime := time.Now() + endTime := holdStartTime.Add(time.Millisecond * time.Duration(timeoutMs)) + + for time.Now().Before(endTime) && ok && !bgStatus.IsZero() && bgStatus.GetCurrentPhase() == driver_infrastructure.IN_PROGRESS { + r.Delay(SLEEP_TIME_DURATION, bgStatus, pluginService, r.bgId) + } + + bgStatus, ok = pluginService.GetBgStatus(r.bgId) + + if ok && !bgStatus.IsZero() && bgStatus.GetCurrentPhase() == driver_infrastructure.IN_PROGRESS { + return nil, error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.inProgressTryConnectLater", timeoutMs)) + } + message := error_util.GetMessage("BlueGreenDeployment.switchoverCompleteContinueWithConnect", time.Since(holdStartTime)) + slog.Debug(message) + + return nil, error_util.NewGenericAwsWrapperError(message) +} + +func NewSuspendConnectRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole, bgId string) *SuspendConnectRouting { + return &SuspendConnectRouting{bgId, NewBaseRouting(hostAndPort, role)} +} + +type SuspendUntilCorrespondingHostFoundConnectRouting struct { + bgId string + BaseRouting +} + +func (r *SuspendUntilCorrespondingHostFoundConnectRouting) Apply(_ driver_infrastructure.ConnectionPlugin, hostInfo *host_info_util.HostInfo, props map[string]string, + _ bool, pluginService driver_infrastructure.PluginService) (driver.Conn, error) { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.waitConnectUntilCorrespondingHostFound", hostInfo.GetHost())) + parentCtx := pluginService.GetTelemetryContext() + telemetryFactory := pluginService.GetTelemetryFactory() + telemetryCtx, ctx := telemetryFactory.OpenTelemetryContext(TELEMETRY_SWITCHOVER, telemetry.NESTED, parentCtx) + + pluginService.SetTelemetryContext(ctx) + defer func() { + telemetryCtx.CloseContext() + pluginService.SetTelemetryContext(parentCtx) + }() + + bgStatus, ok := pluginService.GetBgStatus(r.bgId) + var correspondingPair utils.Pair[*host_info_util.HostInfo, *host_info_util.HostInfo] + if ok && !bgStatus.IsZero() { + correspondingPair = bgStatus.GetCorrespondingHosts()[hostInfo.Host] + } + + timeoutMs := property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_CONNECT_TIMEOUT_MS) + holdStartTime := time.Now() + endTime := holdStartTime.Add(time.Millisecond * time.Duration(timeoutMs)) + + for time.Now().Before(endTime) && ok && !bgStatus.IsZero() && bgStatus.GetCurrentPhase() != driver_infrastructure.COMPLETED && + correspondingPair.GetRight().IsNil() { + r.Delay(SLEEP_TIME_DURATION, bgStatus, pluginService, r.bgId) + bgStatus, ok = pluginService.GetBgStatus(r.bgId) + if ok && !bgStatus.IsZero() { + correspondingPair = bgStatus.GetCorrespondingHosts()[hostInfo.Host] + } + } + + if bgStatus.IsZero() || bgStatus.GetCurrentPhase() == driver_infrastructure.COMPLETED { + message := error_util.GetMessage("BlueGreenDeployment.completedContinueWithConnect", time.Since(holdStartTime)) + slog.Debug(message) + return nil, error_util.NewGenericAwsWrapperError(message) + } else if time.Now().After(endTime) { + return nil, error_util.NewGenericAwsWrapperError(error_util.GetMessage("BlueGreenDeployment.correspondingHostNotFoundTryConnectLater", hostInfo.GetHost(), timeoutMs)) + } + + message := error_util.GetMessage("BlueGreenDeployment.correspondingHostFoundContinueWithConnect", hostInfo.GetHost(), time.Since(holdStartTime)) + slog.Debug(message) + + return nil, nil +} + +func NewSuspendUntilCorrespondingHostFoundConnectRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole, + bgId string) *SuspendUntilCorrespondingHostFoundConnectRouting { + return &SuspendUntilCorrespondingHostFoundConnectRouting{bgId, NewBaseRouting(hostAndPort, role)} +} diff --git a/awssql/plugins/bg/execute_routing.go b/awssql/plugins/bg/execute_routing.go new file mode 100644 index 00000000..99b33da0 --- /dev/null +++ b/awssql/plugins/bg/execute_routing.go @@ -0,0 +1,71 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package bg + +import ( + "log/slog" + "time" + + "github.com/aws/aws-advanced-go-wrapper/awssql/driver_infrastructure" + "github.com/aws/aws-advanced-go-wrapper/awssql/error_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/property_util" + "github.com/aws/aws-advanced-go-wrapper/awssql/utils/telemetry" +) + +type SuspendExecuteRouting struct { + bgId string + BaseRouting +} + +func (r *SuspendExecuteRouting) Apply(_ driver_infrastructure.ConnectionPlugin, props map[string]string, + pluginService driver_infrastructure.PluginService, methodName string, _ driver_infrastructure.ExecuteFunc, + _ ...any) driver_infrastructure.RoutingResultHolder { + slog.Debug(error_util.GetMessage("BlueGreenDeployment.inProgressSuspendMethod", methodName)) + parentCtx := pluginService.GetTelemetryContext() + telemetryFactory := pluginService.GetTelemetryFactory() + telemetryCtx, ctx := telemetryFactory.OpenTelemetryContext(TELEMETRY_SWITCHOVER, telemetry.NESTED, parentCtx) + + pluginService.SetTelemetryContext(ctx) + defer func() { + telemetryCtx.CloseContext() + pluginService.SetTelemetryContext(parentCtx) + }() + + bgStatus, ok := pluginService.GetBgStatus(r.bgId) + + timeoutMs := property_util.GetVerifiedWrapperPropertyValue[int](props, property_util.BG_CONNECT_TIMEOUT_MS) + holdStartTime := time.Now() + endTime := holdStartTime.Add(time.Millisecond * time.Duration(timeoutMs)) + + for time.Now().Before(endTime) && ok && !bgStatus.IsZero() && bgStatus.GetCurrentPhase() == driver_infrastructure.IN_PROGRESS { + r.Delay(SLEEP_TIME_DURATION, bgStatus, pluginService, r.bgId) + } + + bgStatus, ok = pluginService.GetBgStatus(r.bgId) + + if ok && !bgStatus.IsZero() && bgStatus.GetCurrentPhase() == driver_infrastructure.IN_PROGRESS { + return driver_infrastructure.RoutingResultHolder{WrappedErr: error_util.NewGenericAwsWrapperError( + error_util.GetMessage("BlueGreenDeployment.inProgressTryMethodLater", timeoutMs, methodName))} + } + slog.Debug(error_util.GetMessage("BlueGreenDeployment.switchoverCompletedContinueWithMethod", methodName, time.Since(holdStartTime))) + + return driver_infrastructure.EMPTY_ROUTING_RESULT_HOLDER +} + +func NewSuspendExecuteRouting(hostAndPort string, role driver_infrastructure.BlueGreenRole, bgId string) *SuspendExecuteRouting { + return &SuspendExecuteRouting{bgId, NewBaseRouting(hostAndPort, role)} +} diff --git a/awssql/plugins/default_plugin.go b/awssql/plugins/default_plugin.go index c58126ac..8ec958f5 100644 --- a/awssql/plugins/default_plugin.go +++ b/awssql/plugins/default_plugin.go @@ -34,8 +34,11 @@ type DefaultPlugin struct { ConnProviderManager driver_infrastructure.ConnectionProviderManager } +func (d *DefaultPlugin) GetPluginCode() string { + return "default" // plugin code is not used +} + func (d *DefaultPlugin) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { diff --git a/awssql/plugins/efm/host_monitoring_plugin.go b/awssql/plugins/efm/host_monitoring_plugin.go index 130acee3..7213479b 100644 --- a/awssql/plugins/efm/host_monitoring_plugin.go +++ b/awssql/plugins/efm/host_monitoring_plugin.go @@ -82,6 +82,10 @@ type HostMonitorConnectionPlugin struct { plugins.BaseConnectionPlugin } +func (b *HostMonitorConnectionPlugin) GetPluginCode() string { + return driver_infrastructure.EFM_PLUGIN_CODE +} + func (b *HostMonitorConnectionPlugin) GetSubscribedMethods() []string { return []string{"*"} } diff --git a/awssql/plugins/efm/monitor.go b/awssql/plugins/efm/monitor.go index df2893e0..dac97e84 100644 --- a/awssql/plugins/efm/monitor.go +++ b/awssql/plugins/efm/monitor.go @@ -51,7 +51,7 @@ func NewMonitorImpl( failureDetectionCount int, abortedConnectionsCounter telemetry.TelemetryCounter, ) *MonitorImpl { - monitoringConnectionProps := props + monitoringConnectionProps := utils.CreateMapCopy(props) for propKey, propValue := range props { if strings.HasPrefix(propKey, property_util.MONITORING_PROPERTY_PREFIX) { monitoringConnectionProps[strings.TrimPrefix(propKey, property_util.MONITORING_PROPERTY_PREFIX)] = propValue diff --git a/awssql/plugins/efm/monitor_service.go b/awssql/plugins/efm/monitor_service.go index 322404bd..555f05e7 100644 --- a/awssql/plugins/efm/monitor_service.go +++ b/awssql/plugins/efm/monitor_service.go @@ -46,7 +46,7 @@ type MonitorServiceImpl struct { func NewMonitorServiceImpl(pluginService driver_infrastructure.PluginService) (*MonitorServiceImpl, error) { if EFM_MONITORS == nil { - EFM_MONITORS = utils.NewSlidingExpirationCache[Monitor]( + EFM_MONITORS = utils.NewSlidingExpirationCache( "efm_monitors", func(monitor Monitor) bool { monitor.Close() diff --git a/awssql/plugins/execution_time_plugin.go b/awssql/plugins/execution_time_plugin.go index fcc6722b..83bfe997 100644 --- a/awssql/plugins/execution_time_plugin.go +++ b/awssql/plugins/execution_time_plugin.go @@ -50,6 +50,10 @@ func NewExecutionTimePlugin(pluginService driver_infrastructure.PluginService, return &ExecutionTimePlugin{}, nil } +func (d *ExecutionTimePlugin) GetPluginCode() string { + return driver_infrastructure.EXECUTION_TIME_PLUGIN_CODE +} + func (d *ExecutionTimePlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.ALL_METHODS} } diff --git a/awssql/plugins/failover_plugin.go b/awssql/plugins/failover_plugin.go index 30f0bcc4..1581dffc 100644 --- a/awssql/plugins/failover_plugin.go +++ b/awssql/plugins/failover_plugin.go @@ -146,6 +146,10 @@ func NewFailoverPlugin(pluginService driver_infrastructure.PluginService, props }, nil } +func (p *FailoverPlugin) GetPluginCode() string { + return driver_infrastructure.FAILOVER_PLUGIN_CODE +} + func (p *FailoverPlugin) GetSubscribedMethods() []string { return append([]string{ plugin_helpers.CONNECT_METHOD, @@ -154,7 +158,6 @@ func (p *FailoverPlugin) GetSubscribedMethods() []string { } func (p *FailoverPlugin) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { diff --git a/awssql/plugins/limitless/limitless_plugin.go b/awssql/plugins/limitless/limitless_plugin.go index 505b3155..0d7acd9a 100644 --- a/awssql/plugins/limitless/limitless_plugin.go +++ b/awssql/plugins/limitless/limitless_plugin.go @@ -74,6 +74,10 @@ func NewLimitlessPluginWithRouterService(pluginService driver_infrastructure.Plu } } +func (plugin *LimitlessPlugin) GetPluginCode() string { + return driver_infrastructure.LIMITLESS_PLUGIN_CODE +} + func (plugin *LimitlessPlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.CONNECT_METHOD} } diff --git a/awssql/plugins/read_write_splitting/read_write_splitting_plugin.go b/awssql/plugins/read_write_splitting/read_write_splitting_plugin.go index 9a4f24a3..ded1ed82 100644 --- a/awssql/plugins/read_write_splitting/read_write_splitting_plugin.go +++ b/awssql/plugins/read_write_splitting/read_write_splitting_plugin.go @@ -67,6 +67,10 @@ func NewReadWriteSplittingPlugin(pluginService driver_infrastructure.PluginServi } } +func (r *ReadWriteSplittingPlugin) GetPluginCode() string { + return driver_infrastructure.READ_WRITE_SPLITTING_PLUGIN_CODE +} + func (r *ReadWriteSplittingPlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.CONNECT_METHOD, plugin_helpers.INIT_HOST_PROVIDER_METHOD, @@ -117,7 +121,6 @@ func (r *ReadWriteSplittingPlugin) Connect( } func (r *ReadWriteSplittingPlugin) InitHostProvider( - initialUrl string, props map[string]string, hostListProviderService driver_infrastructure.HostListProviderService, initHostProviderFunc func() error) error { diff --git a/awssql/property_util/aws_wrapper_property.go b/awssql/property_util/aws_wrapper_property.go index c207d723..db781c36 100644 --- a/awssql/property_util/aws_wrapper_property.go +++ b/awssql/property_util/aws_wrapper_property.go @@ -28,8 +28,16 @@ import ( const DEFAULT_PLUGINS = "failover,efm" const MONITORING_PROPERTY_PREFIX = "monitoring-" const LIMITLESS_PROPERTY_PREFIX = "limitless" -const INTERNAL_CONNECT_PROPERTY_NAME string = "76c06979-49c4-4c86-9600-a63605b83f50" -const SET_READ_ONLY_CTX_KEY string = "setReadOnly" +const INTERNAL_CONNECT_PROPERTY_NAME = "76c06979-49c4-4c86-9600-a63605b83f50" +const SET_READ_ONLY_CTX_KEY = "setReadOnly" +const BG_PROPERTY_PREFIX = "blue-green-monitoring-" + +var INTERNAL_PROPS_PREFIXES = []string{ + MONITORING_PROPERTY_PREFIX, + INTERNAL_CONNECT_PROPERTY_NAME, + LIMITLESS_PROPERTY_PREFIX, + BG_PROPERTY_PREFIX, +} type WrapperPropertyType int @@ -185,6 +193,13 @@ var ALL_WRAPPER_PROPERTIES = map[string]bool{ RESET_SESSION_STATE_ON_CLOSE.Name: true, ROLLBACK_ON_SWITCH.Name: true, READER_HOST_SELECTOR_STRATEGY.Name: true, + BG_CONNECT_TIMEOUT_MS.Name: true, + BGD_ID.Name: true, + BG_INTERVAL_BASELINE_MS.Name: true, + BG_INTERVAL_INCREASED_MS.Name: true, + BG_INTERVAL_HIGH_MS.Name: true, + BG_SWITCHOVER_TIMEOUT_MS.Name: true, + BG_SUSPEND_NEW_BLUE_CONNECTIONS.Name: true, } var USER = AwsWrapperProperty{ @@ -644,16 +659,73 @@ var READER_HOST_SELECTOR_STRATEGY = AwsWrapperProperty{ wrapperPropertyType: WRAPPER_TYPE_STRING, } +var BG_CONNECT_TIMEOUT_MS = AwsWrapperProperty{ + Name: "bgConnectTimeoutMs", + description: "Connect timeout in milliseconds during Blue/Green Deployment switchover.", + defaultValue: "30000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var BGD_ID = AwsWrapperProperty{ + Name: "bgdId", + description: "Blue/Green Deployment ID", + defaultValue: "1", + wrapperPropertyType: WRAPPER_TYPE_STRING, +} + +var BG_INTERVAL_BASELINE_MS = AwsWrapperProperty{ + Name: "bgBaselineMs", + description: "Baseline Blue/Green Deployment status checking interval in milliseconds.", + defaultValue: "60000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var BG_INTERVAL_INCREASED_MS = AwsWrapperProperty{ + Name: "bgIncreasedMs", + description: "Increased Blue/Green Deployment status checking interval in milliseconds.", + defaultValue: "1000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var BG_INTERVAL_HIGH_MS = AwsWrapperProperty{ + Name: "bgHighMs", + description: "High Blue/Green Deployment status checking interval in milliseconds.", + defaultValue: "100", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var BG_SWITCHOVER_TIMEOUT_MS = AwsWrapperProperty{ + Name: "bgSwitchoverTimeoutMs", + description: "Blue/Green Deployment switchover timeout in milliseconds.", + defaultValue: "180000", + wrapperPropertyType: WRAPPER_TYPE_INT, +} + +var BG_SUSPEND_NEW_BLUE_CONNECTIONS = AwsWrapperProperty{ + Name: "bgSuspendNewBlueConnections", + description: "Enables Blue/Green Deployment switchover to suspend new blue connection requests while the switchover process is in progress.", + defaultValue: "false", + wrapperPropertyType: WRAPPER_TYPE_BOOL, +} + func RemoveInternalAwsWrapperProperties(props map[string]string) map[string]string { copyProps := map[string]string{} for key, value := range props { - // Monitoring properties and the internal connect property flag are not included in copy. - if !strings.HasPrefix(key, MONITORING_PROPERTY_PREFIX) && !strings.HasPrefix(key, LIMITLESS_PROPERTY_PREFIX) && - key != INTERNAL_CONNECT_PROPERTY_NAME { + // Properties that start with monitoring/internal connect prefixes are not included in copy. + if !startsWithPrefix(key) { copyProps[key] = value } } return copyProps } + +func startsWithPrefix(key string) bool { + for _, prefix := range INTERNAL_PROPS_PREFIXES { + if strings.HasPrefix(key, prefix) { + return true + } + } + return false +} diff --git a/awssql/resources/en.json b/awssql/resources/en.json index 55158acf..9171f1d9 100644 --- a/awssql/resources/en.json +++ b/awssql/resources/en.json @@ -1,219 +1,259 @@ { - "AdfsCredentialsProviderFactory.failedLogin": "Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response.", - "AdfsCredentialsProviderFactory.invalidHttpsUrl": "Invalid HTTPS URL: '%s'.", - "AdfsCredentialsProviderFactory.signOnPagePostActionUrl": "ADFS SignOn Action URL: '%s'.", - "AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed": "ADFS SignOn Page POST action failed with HTTP status '%s'.", - "AdfsCredentialsProviderFactory.signOnPageRequestFailed": "ADFS SignOn Page Request Failed with HTTP status '%s'.", - "AdfsCredentialsProviderFactory.signOnPageUrl": "ADFS SignOn URL: '%s'.", - "AuthHelpers.missingRequiredParameters": "Missing required parameter(s) for plugin '%s': '%v'.", - "AwsClientHelper.errorGettingAwsCredentialsProvider": "Error occurred while getting aws.CredentialsProvider: '%v'.", - "AwsClientHelper.errorGettingClientConfig": "Error occurred while loading configuration for aws client: '%v'.", - "AwsSecretsManagerConnectionPlugin.endpointOverrideMisconfigured" : "The provided endpoint is invalid and could not be used to create a URI: '%v'.", - "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials" : "Was not able to either fetch or read the database credentials from AWS Secrets Manager. Ensure the correct secretId and region properties have been provided.", - "AwsSecretsManagerConnectionPlugin.invalidRegion": "Invalid AWS Secrets Manager Region was given: '%s'.", - "AwsSecretsManagerConnectionPlugin.secretIdMissing": "A secret id or a secret arn must be provided in the '%v' property.", - "AwsSecretsManagerConnectionPlugin.unableToCreateAwsSecretsManagerClient": "Error occurred while initializing the AwsSecretsManager client", - "AwsSecretsManagerConnectionPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are not providing a secret ARN, please set the '%v' property.", - "AwsSecretsManagerConnectionPlugin.unableToGetSecretValue" : "Error occurred while getting secret value from AwsSecretsManager: '%v'. ", - "AwsSecretsManagerConnectionPlugin.unableToParseSecretValue" : "Error occurred while parsing the secret value from AwsSecretsManager: '%v'. Make sure that 'username' and 'password' entries are present in the secrets.", - "AwsSecretsManagerConnectionPlugin.useCachedSecret" : "Use cached secret.", - "AwsWrapper.illegalArgumentError": "Illegal argument for property '%s' in method '%s' was provided: '%s'.", - "AwsWrapper.initializingDatabaseHandle": "Initializing database handler with the following properties: '%v'.", - "AwsWrapper.unsupportedMethodError": "Method '%v' not supported by %s.", - "AwsWrapperExecuteWithPlugins.unableToCastResult": "Returned result does not match expected type '%v'.", - "AwsWrapperProperty.unexpectedType": "Value of property '%v' was not the expected type. Received: '%v'. Returning zero value.", - "AwsWrapperProperty.requiresNonNegativeIntValue": "Value of integer property '%v' is negative. Ensure a non-negative property value is provided for intended wrapper behaviour.", - "AwsWrapperProperty.noTimeoutValue": "Property '%v' given value of '%v'. This results in no timeout enforcement and can lead to hanging requests in authentication flows. Please provide a positive timeout value if this is not intended.", - "AwsWrapperProperty.noExpirationValue": "Property '%v' given value of '%v'. Saved resources will expire before they can be used, and will be regenerated each time they are required. Please provide a positive expiration value if this is not intended.", - "AwsWrapperProperty.noRefreshRateValue": "Property '%v' given value of '%v'. Topology/routers will be continuously fetched. Please provide a positive value if this is not intended.", - "AwsWrapperRows.underlyingRowsDoNotImplementRequiredInterface": "The underlying rows do not implement the required interface '%v'.", - "AwsWrapperStmt.underlyingStmtDoesNotImplementRequiredInterface": "The underlying driver statement does not implement the required interface '%v'.", - "AuthenticationToken.generatedNewToken": "Generated new authentication token.", - "AuthenticationToken.useCachedToken": "Use cached authentication token.", - "ClusterTopologyMonitorImpl.errorFetchingTopology": "An error occurred while querying for topology: '%s'.", - "ClusterTopologyMonitorImpl.ignoringTopologyRequest": "A topology refresh was requested, but the topology was already updated recently. Returning cached hosts:", - "ClusterTopologyMonitorImpl.openedMonitoringConnection": "Opened monitoring connection to host '%s'.", - "ClusterTopologyMonitorImpl.startingHostMonitoringRoutines": "Starting host monitoring routines.", - "ClusterTopologyMonitorImpl.startMonitoringRoutine": "Start cluster topology monitoring routine for '%s'.", - "ClusterTopologyMonitorImpl.timeoutSetToZero": "A topology refresh was requested, but the given timeout for the request was 0ms. Returning cached hosts:", - "ClusterTopologyMonitorImpl.topologyNotUpdated": "Topology hasn't been updated after %v ms.", - "ClusterTopologyMonitorImpl.writerMonitoringConnection": "The monitoring connection is connected to a writer: '%s'.", - "ClusterTopologyMonitorImpl.writerPickedUpFromHostMonitors": "The writer host detected by the host monitors was picked up by the topology monitor: '%s'.", - "Conn.doesNotImplementRequiredInterface": "The given connection does not implement the required interface '%v'.", - "Conn.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%v'.", - "ConnectionPluginManager.unknownPluginCode": "Unknown plugin code: '%s'. Please ensure all plugin codes are valid and any required plugin modules have been imported.", - "ConnectionProvider.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%v' specified for this connection provider '%T'. Please visit the documentation for all supported strategies.", - "DatabaseDialect.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%s'.", - "DatabaseDialect.usingMonitoringHostListProvider": "Failover is enabled. Using MonitoringRdsHostListProvider.", - "DatabaseDialect.usingRdsHostListProvider": "Failover is not enabled. Using RdsHostListProvider.", - "DatabaseDialectManager.missingWrapperDriver": "The AWS Advanced Go Wrapper driver name '%s' has not been registered. Please ensure any required driver modules have been imported.", - "DatabaseDialectManager.invalidDriverProtocol": "Invalid driver protocol from properties: %v.", - "DatabaseDialectManager.getDialectError": "Was not able to get a database dialect.", - "DatabaseDialectManager.unknownDialectCode": "Unknown database dialect code: %v.", - "DefaultConnectionPlugin.noHostsAvailable": "The default connection plugin received an empty host list from the plugin service.", - "DefaultTelemetryFactory.telemetryFactoryUnavailable": "Telemetry factory from code '%s' could not be found. Please ensure the corresponding telemetry module has been imported.", - "DefaultTelemetryFactory.missingTelemetryFactory": "Missing telemetry factory: '%s'.", - "DefaultTelemetryFactory.invalidBackend": "Invalid telemetry backend: '%s'.", - "Driver.connectionNotOpen": "Initial connection isn't open.", - "Driver.missingUnderlyingDriverOrDialect": "The underlying driver or driver dialect could not be found.", - "DsnHostListProvider.parsedListEmpty": "Can't parse dsn.", - "DsnHostListProvider.unsupportedGetClusterId": "DsnHostListProvider does not support GetClusterId.", - "DsnHostListProvider.unsupportedGetHostRole": "DsnHostListProvider does not support GetHostRole", - "DsnHostListProvider.unsupportedIdentifyConnection": "DsnHostListProvider does not support IdentifyConnection.", - "DsnParser.failedToSplitHostPort": "Failed to split host:port in '%s', err: %w", - "DsnParser.invalidAddress": "Invalid address from DSN string: '%v'.", - "DsnParser.invalidBackslash": "Invalid backslash found in DSN string: '%v'.", - "DsnParser.invalidDatabaseNoSlash": "Invalid DSN with no slash separating the database name. DSN string: '%v'.", - "DsnParser.invalidKeyValue": "Invalid key value from DSN string: '%v'.", - "DsnParser.unableToDetermineProtocol": "Unable to determine protocol of DSN string: '%v'.", - "DsnParser.unterminatedQuotedString": "Unterminated quoted string in DSN string: '%v'.", - "DsnParser.unableToMatchPortsToHosts": "Unable to match ports to hosts in DSN string. Only DSNs with one port are supported.", - "ExecutionTimePlugin.executionTime": "Executed method '%v' in '%v' milliseconds.", - "Failover.connectionChangedError": "The active SQL connection has changed due to a connection failure. Please re-configure session state if required.", - "Failover.connectionClosedExplicitly": "Unable to failover, the connection has been explicitly closed.", - "Failover.detectedError": "Detected an error while executing a command: '%s'.", - "Failover.errorSelectingReaderHost": "An error occurred while attempting to select a reader host candidate: '%s'. Candidates:", - "Failover.errorConnectingToWriter": "An error occurred while trying to connect to the new writer '%s'.", - "Failover.establishedConnection": "Connected to: '%s'.", - "Failover.failedReaderConnection": "[Reader Failover] Failed to connect to host: '%s'", - "Failover.failoverDisabled": "Cluster-aware failover is disabled.", - "Failover.failoverReaderTimeout": "The reader failover process was not able to establish a connection before timing out.", - "Failover.failoverReaderUnableToRefreshHostList": "The request to discover the new topology was unsuccessful.", - "Failover.noOperationsAfterConnectionClosed": "No operations allowed after connection closed.", - "Failover.noWriterHost": "Unable to find writer in updated host list:", - "Failover.parameterValue": "Failover parameter: %s=%s.", - "Failover.readerCandidateNil": "Reader candidate is and unable to be selected.", - "Failover.readerFailoverElapsed": "Reader failover elapsed in %v.", - "Failover.strictReaderUnknownHostRole": "Unable to determine host role for '%s'. Since failover mode is set to STRICT_READER and the host may be a writer, it will not be selected for reader failover.", - "Failover.startReaderFailover": "Starting reader failover procedure.", - "Failover.startWriterFailover": "Starting writer failover procedure.", - "Failover.timeoutError": "Internal failover task has timed out.", - "Failover.transactionResolutionUnknownError": "Transaction resolution unknown. Please re-configure session state if required and try restarting the transaction.", - "Failover.unableToRefreshHostList": "The request to discover the new topology timed out or was unsuccessful.", - "Failover.unableToConnect": "Unable to establish a SQL connection due to an unexpected error.", - "Failover.unableToConnectToReader": "Unable to establish SQL connection to the reader instance.", - "Failover.unexpectedReaderRole": "The new writer was identified to be '%s', but querying the instance for its role returned a role of '%s'.", - "Failover.writerFailoverElapsed": "Writer failover elapsed in %v.", - "FederatedAuthPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%s' property.", - "HighestWeightHostSelector.noHostsMatchingRole": "No available hosts were found matching the requested '%v' role.", - "HostInfoBuilder.InvalidEmptyHost": "HostInfoBuilder Host parameter must be set.", - "HostMonitoringConnectionPlugin.activatedMonitoring": "Executing method %v, monitoring is activated.", - "HostMonitoringConnectionPlugin.clusterHostInfoRequired": "Monitoring HostInfo is associated with a cluster endpoint, plugin needs to identify the cluster connection.", - "HostMonitoringConnectionPlugin.errorIdentifyingConnection": "Error occurred while identifying connection: %v.", - "HostMonitoringConnectionPlugin.errorGettingMonitoringHostInfo": "Error occurred while getting monitoring HostInfo: %v.", - "HostMonitoringConnectionPlugin.illegalArgumentError": "Illegal argument for '%v' was given to the HostMonitoringConnectionPlugin.", - "HostMonitoringConnectionPlugin.monitoringDeactivated": "Monitoring deactivated for method %v.", - "HostMonitoringConnectionPlugin.unableToIdentifyConnection": "Unable to identify the given connection %v : %v. Please ensure the correct host list provider is specified.", - "HostMonitoringConnectionPlugin.unavailableHost": "Host '%v' is unavailable.", - "HostMonitoringRoutine.detectedWriter": "Writer detected by host monitoring routine: '%s'.", - "HostMonitoringRoutine.routineCompleted": "Host monitoring routine completed in %v.", - "HostMonitoringRoutine.writerHostChanged": "Writer host changed from '%s' to host '%s'.", - "HostSelector.noHostsMatchingRole": "No available hosts were found matching the requested '%v' role.", - "HostSelector.invalidHostWeightPairs": "The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1.", - "IamAuthPlugin.connectionError": "Error occurred while opening a connection: '%v'.", - "IamAuthPlugin.errorGeneratingNewToken": "Error occurred while generating authentication token: '%v'.", - "IamAuthPlugin.errorGettingAwsCredentialsProvider": "Error occurred while getting aws.CredentialsProvider: '%v'.", - "IamAuthPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%v' property.", - "IamAuthPlugin.useCachedToken": "Using cached authentication token.", - "InternalPooledConn.UnsupportedOperation": "The underlying driver does not implement '%v'.", - "InternalPooledConn.CannotResetSession": "Could not reset session for internal pool conn. Closing the connection.", - "LimitlessPlugin.expectedShardGroupUrl": "The provided host was not a Limitless DB shard group URL: '%s'.", - "LimitlessPlugin.failedToConnectToHost": "Limitless Plugin failed to connect to host %v.", - "LimitlessPlugin.invalidDatabaseUrl": "Invalid Limitless Database URL '%v'. Please use a valid Limitless DB Shard Group endpoint URL.", - "LimitlessPlugin.unsupportedDialectOrDatabase": "Unsupported dialect '%T' encountered. Please ensure connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin.", - "LimitlessQueryHelperImpl.unableToFetchRouterName": "Unable to fetch Limitless Router Name.", - "LimitlessQueryHelperImpl.unableToFetchFetchingRouterLoad": "Unable to fetch Limitless Router Load.", - "LimitlessQueryHelperImpl.invalidDatabaseDialect": "Unable to fetch Limitless Routers due to invalid database dialect: '%T'. Please ensure that the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin.", - "LimitlessQueryHelperImpl.invalidQuery": "Limitless Connection Plugin has encountered an error obtaining Limitless Router endpoints. Please ensure that you are connecting to an Aurora Limitless Database Shard Group Endpoint URL.", - "LimitlessQueryHelperImpl.invalidRouterLoad": "Invalid load metric value of %v from the transaction router query aurora_limitless_router_endpoints() for transaction router '%v'. The load metric value must be a decimal value between 0 and 1. Host weight will be assigned a default weight of 1.", - "LimitlessRouterMonitorImpl.closeMonitoring": "Closing Limitless monitoring for router: %v.", - "LimitlessRouterServiceImpl.errorStartingMonitor": "An error occurred while starting Limitless Router Monitor: '%v'.", - "LimitlessRouterServiceImpl.failedToConnectToRouter": "Failed to connect to Limitless router: %v.", - "LimitlessRouterMonitorImpl.hostSelectorStrategyNotFound": "Limitless Router Monitor unable to fetch host selector strategy and update host selector strategy weights.", - "LimitlessRouterServiceImpl.limitlessRouterCacheEmpty": "Limitless Router cache is empty. This is normal during application start up when the cache is not yet populated.", - "LimitlessRouterServiceImpl.maxConnectRetriesExceeded": "Max Limitless connection retries has been exceeded. Unable to connect to any transaction routers.", - "LimitlessRouterServiceImpl.noRoutersAvailable": "Unable to connect to any Limitless transaction routers", - "LimitlessRouterMonitorImpl.noRoutersFetched": "Unable to discover any Limitless transaction routers.", - "LimitlessRouterMonitorImpl.openedConnection": "Opened Limitless Router Monitor connection to %v.", - "LimitlessRouterMonitorImpl.openingConnection": "Opening Limitless Router Monitor connection to %v.", - "LimitlessRouterServiceImpl.selectedHost": "Limitless Router '%v' has been selected.", - "LimitlessRouterMonitorImpl.startMonitoring": "Starting Limitless monitoring for router: %v.", - "LimitlessRouterMonitorImpl.stopMonitoring": "Stopped Limitless monitoring for router: %v.", - "LimitlessRouterMonitorImpl.unableToCastHostSelector": "Limitless Router Monitor unable to cast fetched host selector strategy and update host selector weights.", - "LimitlessRouterServiceImpl.unableToConnectNoRoutersAvailable": "Unable to connect to original host %v. All transaction routers are unavailable. Please verify connection credentials and network connectivity.", - "MonitorImpl.stopped": "Stopped monitoring routine for host %v.", - "MonitorImpl.monitorIsStopped": "Monitoring was already stopped for host %v.", - "MonitorImpl.startMonitoringRoutineNewState": "Start monitoring routine checking for new states for %v.", - "MonitorImpl.stopMonitoringRoutineNewState": "Stop monitoring routine checking for new states for %v.", - "MonitorImpl.startMonitoringRoutine": "Start monitoring routine for %v.", - "MonitorImpl.stopMonitoringRoutine": "Stop monitoring routine for %v.", - "MonitorImpl.openingMonitoringConnection": "Opening monitoring connection to %v.", - "MonitorImpl.openedMonitoringConnection": "Opened monitoring connection to %v.", - "MonitorImpl.hostNotResponding": "Host %v is not responding.", - "MonitorImpl.hostDead": "Host %v is dead.", - "MonitorImpl.hostAlive": "Host %v is alive.", - "MonitorImpl.updatingActiveStates": "Updating active states for %v. Going from %v state(s) to %v.", - "MonitorServiceImpl.errorAbortingConn": "An error was thrown when aborting monitoring Conn: %v.", - "MonitorServiceImpl.illegalArgumentError": "Illegal argument for '%v' was given to MonitorServiceImpl.", - "OktaAuthPlugin.httpNon200StatusCode": "Did not get a 200 response from http request to '%v'. Received the following status code: '%v'.", - "OktaAuthPlugin.unableToDetermineRegion": "Unable to determine the region. Please set the '%v' parameter.", - "OktaAuthPlugin.unableToRetrieveSessionToken": "Could not retrieve session token from endpoint.", - "OktaAuthPlugin.failedSamlAssertion": "Could not get SAML response from the following endpoint: '%v'.", - "PluginManager.pipelineNone": "A pipeline was requested but the created pipeline evaluated to nil.", - "PluginManager.unknownPluginCode": "Unknown plugin code: '%v'.", - "PluginManagerImpl.releaseResources": "Releasing resources from PluginManagerImpl.", - "PluginManagerImpl.unsupportedHostSelectionStrategy": "The wrapper does not support the requested host selection strategy: %v.", - "PluginServiceImpl.releaseResources": "Releasing resources from PluginServiceImpl.", - "PluginServiceImpl.requiredBlockingHostListProvider": "The detected host list provider is not a BlockingHostListProvider. A BlockingHostListProvider is required to force refresh the host list. Detected host list provider: '%T'.", - "PluginServiceImpl.nilHost": "Current host evaluates to nil.", - "PluginServiceImpl.initialHostNotSet": "Unable to update dialect, initial HostInfo has not been set.", - "PluginServiceImpl.nilConn": "Unable to set current connection, given connection evaluates to nil.", - "PluginServiceImpl.setCurrentHost": "Set current host to '%v'.", - "PluginServiceImpl.hostListEmpty": "Host list is empty.", - "PluginServiceImpl.hostsChangelistEmpty": "There are no changes in the hosts' availability.", - "PluginServiceImpl.nonEmptyAliases": "FillAliases called when HostInfo already contains the following aliases: '%v'.", - "PluginServiceImpl.failedToRetrieveHostPort": "Could not retrieve Host:Port for connection.", - "PluginManagerImpl.invokedAgainstOldConnection": "The internal connection has changed since %v was created, skip executing method %v. This is likely due to failover. To ensure you are using the updated connection, please re-create Statement, Tx, Result and Row objects after failover.", - "ReadWriteSplittingPlugin.couldNotRefreshHostlist": "Could not refresh host list", - "ReadWriteSplittingPlugin.emptyHostList": "Host list is empty", - "ReadWriteSplittingPlugin.errorSwitchingToCachedReader": "An error occurred while trying to switch to a cached reader connection: '%s'. The driver will attempt to establish a new reader connection.", - "ReadWriteSplittingPlugin.errorSwitchingToReader": "An error occured while trying to switch to a reader connection: '%v'", - "ReadWriteSplittingPlugin.errorSwitchingToWriter": "An error occured while trying to switch to the writer connection", - "ReadWriteSplittingPlugin.errorVerifyingInitialHostRole": "An error occurred while obtaining the connected host's role. This could occur if the client is broken or if you are not connected to an Aurora database.", - "ReadWriteSplittingPlugin.errorWhileExecutingCommand": "[ReadWriteSplitting] Detected an error while executing a command: '%s'", - "ReadWriteSplittingPlugin.failedToConnectToReader": "Failed to connect to reader host: '%s'", - "ReadWriteSplittingPlugin.failoverErrorWhileExecutingCommand": "Detected a failover error while executing a command: '%s'", - "ReadWriteSplittingPlugin.fallbackToWriter": "Failed to switch to reader '%v'. The current writer '%s' will be used as fallback", - "ReadWriteSplittingPlugin.noReadersAvailable": "The plugin was unable to establish a reader connection to any reader instance", - "ReadWriteSplittingPlugin.noWriterFound": "No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts", - "ReadWriteSplittingPlugin.setReadOnlyOnClosedConnection": "Cannot set ReadOnly on closed connection", - "ReadWriteSplittingPlugin.setReadOnlyFalseInTransaction": "readOnly was set to true during a transaction. Please complete the transaction before setting readOnly to true", - "ReadWriteSplittingPlugin.settingCurrentConnection": "Setting the current connection to '%s'", - "ReadWriteSplittingPlugin.switchedFromReaderToWriter": "Switched from a reader to a writer host. New writer host: '%s'", - "ReadWriteSplittingPlugin.switchedFromWriterToReader": "Switched from a writer to a reader host. New reader host: '%s'", - "ReadWriteSplittingPlugin.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%s', specified in plugin configuration parameter '%s'. Please visit the Read/Write Splitting documentation for all supported strategies", - "ReadWriteSplittingPlugin.updateInternalConnectionInfoFailed": "Cannot update internal connection. Error: '%s'", - "RdsHostListProvider.unableToGatherTopology": "Unable to gather topology, no hosts identified from query, cache, or initial host list.", - "RdsHostListProvider.unableToGetHostName": "Unable to retrieve host name from given connection.", - "RdsHostListProvider.givenTemplateInvalid": "The given cluster format template is invalid, using default host formatting.", - "RdsHostListProvider.suggestedClusterId": "ClusterId '%v' is suggested for url '%v'.", - "RdsHostListProvider.unknownHostRole": "Query to gather host role failed, unable to determine role of current host.", - "SamlCredentialsProviderFactory.getSamlAssertionFailed": "Failed to get SAML Assertion due to exception: '%s'.", - "SessionStateService.logState": "Current session state: \n%s", - "SessionStateService.transferIncomplete": "Previous session state transfer has not been completed.", - "SlidingExpirationCache.exitingCacheCleanupRoutine": "Sliding expiration cache '%v' cleanup routine has been cancelled.", - "SlidingExpirationCache.itemDisposal": "Disposing of %s.", - "SlidingExpirationCache.startingCacheCleanupRoutine": "Sliding expiration cache '%v', has been initialized, cleanup routine has started.", - "StaleDnsHelper.clusterEndpointDns": "Cluster endpoint resolves to '%s'.", - "StaleDnsHelper.staleDnsDetected": "Stale DNS data detected. Opening a connection to '%s'.", - "StaleDnsHelper.writerHostInfo": "Writer host: '%s'.", - "StaleDnsHelper.writerInetAddress" : "Writer host address: '%s'.", - "TargetDriverHelper.invalidProtocol": "Invalid database protocol was found: %v", - "TargetDriverHelper.missingDriver": "Cannot find the target driver for %v. Please ensure the target driver is imported and registered. Here is the list of registered drivers found: %v", - "TelemetryContext.castError": "Could not cast provided TelemetryContext to type '%s'.", - "Utils.failedToReadHttpResponse": "Failed to read HTTP response body: '%s'.", - "Utils.rollbackError": "Error occurred when attempting rollback: '%s'.", - "Utils.topology": "%s \n%s", - "WeightedRandomHostSelector.noHostsMatchingRole": "Weighted Random strategy was unable to select a host. No available hosts were found matching the requested '%v' role.", - "WeightedRandomHostSelector.unableToGetHost": "Weighted Random strategy was unable to select a host." + "AdfsCredentialsProviderFactory.failedLogin": "Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response.", + "AdfsCredentialsProviderFactory.invalidHttpsUrl": "Invalid HTTPS URL: '%s'.", + "AdfsCredentialsProviderFactory.signOnPagePostActionRequestFailed": "ADFS SignOn Page POST action failed with HTTP status '%s'.", + "AdfsCredentialsProviderFactory.signOnPagePostActionUrl": "ADFS SignOn Action URL: '%s'.", + "AdfsCredentialsProviderFactory.signOnPageRequestFailed": "ADFS SignOn Page Request Failed with HTTP status '%s'.", + "AdfsCredentialsProviderFactory.signOnPageUrl": "ADFS SignOn URL: '%s'.", + "AuthHelpers.missingRequiredParameters": "Missing required parameter(s) for plugin '%s': '%v'.", + "AuthenticationToken.generatedNewToken": "Generated new authentication token.", + "AuthenticationToken.useCachedToken": "Use cached authentication token.", + "AwsClientHelper.errorGettingAwsCredentialsProvider": "Error occurred while getting aws.CredentialsProvider: '%v'.", + "AwsClientHelper.errorGettingClientConfig": "Error occurred while loading configuration for aws client: '%v'.", + "AwsSecretsManagerConnectionPlugin.endpointOverrideMisconfigured": "The provided endpoint is invalid and could not be used to create a URI: '%v'.", + "AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials": "Was not able to either fetch or read the database credentials from AWS Secrets Manager. Ensure the correct secretId and region properties have been provided.", + "AwsSecretsManagerConnectionPlugin.invalidRegion": "Invalid AWS Secrets Manager Region was given: '%s'.", + "AwsSecretsManagerConnectionPlugin.secretIdMissing": "A secret id or a secret arn must be provided in the '%v' property.", + "AwsSecretsManagerConnectionPlugin.unableToCreateAwsSecretsManagerClient": "Error occurred while initializing the AwsSecretsManager client", + "AwsSecretsManagerConnectionPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are not providing a secret ARN, please set the '%v' property.", + "AwsSecretsManagerConnectionPlugin.unableToGetSecretValue": "Error occurred while getting secret value from AwsSecretsManager: '%v'. ", + "AwsSecretsManagerConnectionPlugin.unableToParseSecretValue": "Error occurred while parsing the secret value from AwsSecretsManager: '%v'. Make sure that 'username' and 'password' entries are present in the secrets.", + "AwsSecretsManagerConnectionPlugin.useCachedSecret": "Use cached secret.", + "AwsWrapper.illegalArgumentError": "Illegal argument for property '%s' in method '%s' was provided: '%s'.", + "AwsWrapper.initializingDatabaseHandle": "Initializing database handler with the following properties: '%v'.", + "AwsWrapper.unsupportedMethodError": "Method '%v' not supported by %s.", + "AwsWrapperExecuteWithPlugins.unableToCastResult": "Returned result does not match expected type '%v'.", + "AwsWrapperProperty.noExpirationValue": "Property '%v' given value of '%v'. Saved resources will expire before they can be used, and will be regenerated each time they are required. Please provide a positive expiration value if this is not intended.", + "AwsWrapperProperty.noRefreshRateValue": "Property '%v' given value of '%v'. Topology/routers will be continuously fetched. Please provide a positive value if this is not intended.", + "AwsWrapperProperty.noTimeoutValue": "Property '%v' given value of '%v'. This results in no timeout enforcement and can lead to hanging requests in authentication flows. Please provide a positive timeout value if this is not intended.", + "AwsWrapperProperty.requiresNonNegativeIntValue": "Value of integer property '%v' is negative. Ensure a non-negative property value is provided for intended wrapper behaviour.", + "AwsWrapperProperty.unexpectedType": "Value of property '%v' was not the expected type. Received: '%v'. Returning zero value.", + "AwsWrapperRows.underlyingRowsDoNotImplementRequiredInterface": "The underlying rows do not implement the required interface '%v'.", + "AwsWrapperStmt.underlyingStmtDoesNotImplementRequiredInterface": "The underlying driver statement does not implement the required interface '%v'.", + "BlueGreenDeployment.allGreenHostChangedName": "All green hosts have changed name.", + "BlueGreenDeployment.bgIdRequired": "Unable to initialize Blue/Green plugin, value for the property bgId must be set.", + "BlueGreenDeployment.blueDnsCompleted": "[bgdId: '%v'] Blue DNS update completed.", + "BlueGreenDeployment.completedContinueWithConnect": "Blue/Green Deployment status is completed. Continue with 'connect' call. The call was held for '%v' ms.", + "BlueGreenDeployment.correspondingHostFoundContinueWithConnect": "A corresponding host for '%v' is found. Continue with connect call. The call was held for '%v' ms.", + "BlueGreenDeployment.correspondingHostNotFoundTryConnectLater": "Blue/Green Deployment switchover is still in progress and a corresponding host for '%v' is not found after '%v' ms. Try to connect again later.", + "BlueGreenDeployment.createHostListProvider": "[%v] Creating a new HostListProvider, clusterId: %v.", + "BlueGreenDeployment.errorGeneratingHash": "Unable to generate hash. Error: %v.", + "BlueGreenDeployment.errorQueryingStatusTable": "Unable to retrieve status table. Error: %v.", + "BlueGreenDeployment.greenDnsRemoved": "[bgdId: '%v'] Green DNS removed.", + "BlueGreenDeployment.greenHostChangedName": "Green host '%v' has changed names, using IAM host '%v'.", + "BlueGreenDeployment.greenTopologyChanged": "[bgdId: '%v'] Green topology changed.", + "BlueGreenDeployment.inProgressCantConnect": "Blue/Green Deployment switchover is in progress. New connection can not be opened.", + "BlueGreenDeployment.inProgressCantOpenConnection": "Blue/Green Deployment switchover is in progress. Can't establish connection to '%v'.", + "BlueGreenDeployment.inProgressConnectionClosed": "Connection has been closed since Blue/Green switchover is in progress.", + "BlueGreenDeployment.inProgressSuspendConnect": "Blue/Green Deployment switchover is in progress. The 'connect' call will be delayed until switchover is completed.", + "BlueGreenDeployment.inProgressSuspendMethod": "Blue/Green Deployment switchover is in progress. Suspend '%v' call until switchover is completed.", + "BlueGreenDeployment.inProgressTryConnectLater": "Blue/Green Deployment switchover is still in progress after '%v' ms. Try to connect again later.", + "BlueGreenDeployment.inProgressTryMethodLater": "Blue/Green Deployment switchover is still in progress after '%v' ms. Try '%v' again later.", + "BlueGreenDeployment.interimStatus": "[bgdId: '%v', role: %v] %v", + "BlueGreenDeployment.monitoringLoopCompleted": "[%v] Blue/green status monitoring loop is completed.", + "BlueGreenDeployment.noEntriesInStatusTable": "[%v] No entries in status table.", + "BlueGreenDeployment.openedConnection": "[%v] Opened monitoring connection to %v.", + "BlueGreenDeployment.openedConnectionWithIp": "[%v] Opened monitoring connection (IP) to %v.", + "BlueGreenDeployment.openingConnection": "[%v] Opening monitoring connection to %v.", + "BlueGreenDeployment.openingConnectionWithIp": "[%v] Opening monitoring connection (IP) to %v.", + "BlueGreenDeployment.requireIamHost": "Connecting with IP address when IAM authentication is enabled requires an ''iamHost'' parameter.", + "BlueGreenDeployment.resetContext": "Blue Green Status Provider resetting context.", + "BlueGreenDeployment.rollback": "[bgdId: '%v'] Blue/Green deployment is in rollback mode.", + "BlueGreenDeployment.statusChanged": "[%v] Status changed to: '%v'.", + "BlueGreenDeployment.statusNotAvailable": "[%v] (status not available) currentPhase: %v.", + "BlueGreenDeployment.switchoverCompleteContinueWithConnect": "Blue/Green Deployment switchover is completed. Continue with connect call. The call was held for '%v' ms.", + "BlueGreenDeployment.switchoverCompletedContinueWithMethod": "Blue/Green Deployment switchover is completed. Continue with '%v' call. The call was held for '%v' ms.", + "BlueGreenDeployment.switchoverTimeout": "Blue/Green switchover has timed out.", + "BlueGreenDeployment.unknownPhase": "[bgdId: '%v'] Unknown BG phase '%v'.", + "BlueGreenDeployment.unknownRole": "Unknown blue/green role: '%s'.", + "BlueGreenDeployment.unknownStatus": "Unknown blue/green status: '%s'.", + "BlueGreenDeployment.unknownVersion": "Unknown blue/green version '%v'.", + "BlueGreenDeployment.unsupportedDialect": "[bgdId: '%v'] Blue/Green Deployments isn't supported by database dialect '%v'.", + "BlueGreenDeployment.waitConnectUntilCorrespondingHostFound": "Blue/Green Deployment switchover is in progress and a corresponding host for '%v' is not found. The ''connect'' call will be delayed.", + "ClusterTopologyMonitorImpl.errorFetchingTopology": "An error occurred while querying for topology: '%s'.", + "ClusterTopologyMonitorImpl.ignoringTopologyRequest": "A topology refresh was requested, but the topology was already updated recently. Returning cached hosts:", + "ClusterTopologyMonitorImpl.openedMonitoringConnection": "Opened monitoring connection to host '%s'.", + "ClusterTopologyMonitorImpl.startMonitoringRoutine": "Start cluster topology monitoring routine for '%s'.", + "ClusterTopologyMonitorImpl.startingHostMonitoringRoutines": "Starting host monitoring routines.", + "ClusterTopologyMonitorImpl.timeoutSetToZero": "A topology refresh was requested, but the given timeout for the request was 0ms. Returning cached hosts:", + "ClusterTopologyMonitorImpl.topologyNotUpdated": "Topology hasn't been updated after %v ms.", + "ClusterTopologyMonitorImpl.writerMonitoringConnection": "The monitoring connection is connected to a writer: '%s'.", + "ClusterTopologyMonitorImpl.writerPickedUpFromHostMonitors": "The writer host detected by the host monitors was picked up by the topology monitor: '%s'.", + "Conn.doesNotImplementRequiredInterface": "The given connection does not implement the required interface '%v'.", + "Conn.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%v'.", + "ConnectionPluginManager.unknownPluginCode": "Unknown plugin code: '%s'. Please ensure all plugin codes are valid and any required plugin modules have been imported.", + "ConnectionProvider.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%v' specified for this connection provider '%T'. Please visit the documentation for all supported strategies.", + "DatabaseDialect.invalidTransactionIsolationLevel": "An invalid transaction isolation level was provided: '%s'.", + "DatabaseDialect.usingMonitoringHostListProvider": "Failover is enabled. Using MonitoringRdsHostListProvider.", + "DatabaseDialect.usingRdsHostListProvider": "Failover is not enabled. Using RdsHostListProvider.", + "DatabaseDialectManager.getDialectError": "Was not able to get a database dialect.", + "DatabaseDialectManager.invalidDriverProtocol": "Invalid driver protocol from properties: %v.", + "DatabaseDialectManager.missingWrapperDriver": "The AWS Advanced Go Wrapper driver name '%s' has not been registered. Please ensure any required driver modules have been imported.", + "DatabaseDialectManager.unknownDialectCode": "Unknown database dialect code: %v.", + "DefaultConnectionPlugin.noHostsAvailable": "The default connection plugin received an empty host list from the plugin service.", + "DefaultTelemetryFactory.invalidBackend": "Invalid telemetry backend: '%s'.", + "DefaultTelemetryFactory.missingTelemetryFactory": "Missing telemetry factory: '%s'.", + "DefaultTelemetryFactory.telemetryFactoryUnavailable": "Telemetry factory from code '%s' could not be found. Please ensure the corresponding telemetry module has been imported.", + "Driver.connectionNotOpen": "Initial connection isn't open.", + "Driver.missingUnderlyingDriverOrDialect": "The underlying driver or driver dialect could not be found.", + "DsnHostListProvider.parsedListEmpty": "Can't parse dsn.", + "DsnHostListProvider.unsupportedGetClusterId": "DsnHostListProvider does not support GetClusterId.", + "DsnHostListProvider.unsupportedGetHostRole": "DsnHostListProvider does not support GetHostRole", + "DsnHostListProvider.unsupportedIdentifyConnection": "DsnHostListProvider does not support IdentifyConnection.", + "DsnParser.failedToSplitHostPort": "Failed to split host:port in '%s', err: %w", + "DsnParser.invalidAddress": "Invalid address from DSN string: '%v'.", + "DsnParser.invalidBackslash": "Invalid backslash found in DSN string: '%v'.", + "DsnParser.invalidDatabaseNoSlash": "Invalid DSN with no slash separating the database name. DSN string: '%v'.", + "DsnParser.invalidKeyValue": "Invalid key value from DSN string: '%v'.", + "DsnParser.unableToDetermineProtocol": "Unable to determine protocol of DSN string: '%v'.", + "DsnParser.unableToMatchPortsToHosts": "Unable to match ports to hosts in DSN string. Only DSNs with one port are supported.", + "DsnParser.unterminatedQuotedString": "Unterminated quoted string in DSN string: '%v'.", + "ExecutionTimePlugin.executionTime": "Executed method '%v' in '%v' milliseconds.", + "Failover.connectionChangedError": "The active SQL connection has changed due to a connection failure. Please re-configure session state if required.", + "Failover.connectionClosedExplicitly": "Unable to failover, the connection has been explicitly closed.", + "Failover.detectedError": "Detected an error while executing a command: '%s'.", + "Failover.errorConnectingToWriter": "An error occurred while trying to connect to the new writer '%s'.", + "Failover.errorSelectingReaderHost": "An error occurred while attempting to select a reader host candidate: '%s'. Candidates:", + "Failover.establishedConnection": "Connected to: '%s'.", + "Failover.failedReaderConnection": "[Reader Failover] Failed to connect to host: '%s'", + "Failover.failoverDisabled": "Cluster-aware failover is disabled.", + "Failover.failoverReaderTimeout": "The reader failover process was not able to establish a connection before timing out.", + "Failover.failoverReaderUnableToRefreshHostList": "The request to discover the new topology was unsuccessful.", + "Failover.noOperationsAfterConnectionClosed": "No operations allowed after connection closed.", + "Failover.noWriterHost": "Unable to find writer in updated host list:", + "Failover.parameterValue": "Failover parameter: %s=%s.", + "Failover.readerCandidateNil": "Reader candidate is and unable to be selected.", + "Failover.readerFailoverElapsed": "Reader failover elapsed in %v.", + "Failover.startReaderFailover": "Starting reader failover procedure.", + "Failover.startWriterFailover": "Starting writer failover procedure.", + "Failover.strictReaderUnknownHostRole": "Unable to determine host role for '%s'. Since failover mode is set to STRICT_READER and the host may be a writer, it will not be selected for reader failover.", + "Failover.timeoutError": "Internal failover task has timed out.", + "Failover.transactionResolutionUnknownError": "Transaction resolution unknown. Please re-configure session state if required and try restarting the transaction.", + "Failover.unableToConnect": "Unable to establish a SQL connection due to an unexpected error.", + "Failover.unableToConnectToReader": "Unable to establish SQL connection to the reader instance.", + "Failover.unableToRefreshHostList": "The request to discover the new topology timed out or was unsuccessful.", + "Failover.unexpectedReaderRole": "The new writer was identified to be '%s', but querying the instance for its role returned a role of '%s'.", + "Failover.writerFailoverElapsed": "Writer failover elapsed in %v.", + "FederatedAuthPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%s' property.", + "HighestWeightHostSelector.noHostsMatchingRole": "No available hosts were found matching the requested '%v' role.", + "HostInfoBuilder.InvalidEmptyHost": "HostInfoBuilder Host parameter must be set.", + "HostMonitoringConnectionPlugin.activatedMonitoring": "Executing method %v, monitoring is activated.", + "HostMonitoringConnectionPlugin.clusterHostInfoRequired": "Monitoring HostInfo is associated with a cluster endpoint, plugin needs to identify the cluster connection.", + "HostMonitoringConnectionPlugin.errorGettingMonitoringHostInfo": "Error occurred while getting monitoring HostInfo: %v.", + "HostMonitoringConnectionPlugin.errorIdentifyingConnection": "Error occurred while identifying connection: %v.", + "HostMonitoringConnectionPlugin.illegalArgumentError": "Illegal argument for '%v' was given to the HostMonitoringConnectionPlugin.", + "HostMonitoringConnectionPlugin.monitoringDeactivated": "Monitoring deactivated for method %v.", + "HostMonitoringConnectionPlugin.unableToIdentifyConnection": "Unable to identify the given connection %v : %v. Please ensure the correct host list provider is specified.", + "HostMonitoringConnectionPlugin.unavailableHost": "Host '%v' is unavailable.", + "HostMonitoringRoutine.detectedWriter": "Writer detected by host monitoring routine: '%s'.", + "HostMonitoringRoutine.routineCompleted": "Host monitoring routine completed in %v.", + "HostMonitoringRoutine.writerHostChanged": "Writer host changed from '%s' to host '%s'.", + "HostSelector.invalidHostWeightPairs": "The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1.", + "HostSelector.noHostsMatchingRole": "No available hosts were found matching the requested '%v' role.", + "IamAuthPlugin.connectionError": "Error occurred while opening a connection: '%v'.", + "IamAuthPlugin.errorGeneratingNewToken": "Error occurred while generating authentication token: '%v'.", + "IamAuthPlugin.errorGettingAwsCredentialsProvider": "Error occurred while getting aws.CredentialsProvider: '%v'.", + "IamAuthPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%v' property.", + "IamAuthPlugin.useCachedToken": "Using cached authentication token.", + "InternalPooledConn.CannotResetSession": "Could not reset session for internal pool conn. Closing the connection.", + "InternalPooledConn.UnsupportedOperation": "The underlying driver does not implement '%v'.", + "LimitlessPlugin.expectedShardGroupUrl": "The provided host was not a Limitless DB shard group URL: '%s'.", + "LimitlessPlugin.failedToConnectToHost": "Limitless Plugin failed to connect to host %v.", + "LimitlessPlugin.invalidDatabaseUrl": "Invalid Limitless Database URL '%v'. Please use a valid Limitless DB Shard Group endpoint URL.", + "LimitlessPlugin.unsupportedDialectOrDatabase": "Unsupported dialect '%T' encountered. Please ensure connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin.", + "LimitlessQueryHelperImpl.invalidDatabaseDialect": "Unable to fetch Limitless Routers due to invalid database dialect: '%T'. Please ensure that the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin.", + "LimitlessQueryHelperImpl.invalidQuery": "Limitless Connection Plugin has encountered an error obtaining Limitless Router endpoints. Please ensure that you are connecting to an Aurora Limitless Database Shard Group Endpoint URL.", + "LimitlessQueryHelperImpl.invalidRouterLoad": "Invalid load metric value of %v from the transaction router query aurora_limitless_router_endpoints() for transaction router '%v'. The load metric value must be a decimal value between 0 and 1. Host weight will be assigned a default weight of 1.", + "LimitlessQueryHelperImpl.unableToFetchFetchingRouterLoad": "Unable to fetch Limitless Router Load.", + "LimitlessQueryHelperImpl.unableToFetchRouterName": "Unable to fetch Limitless Router Name.", + "LimitlessRouterMonitorImpl.closeMonitoring": "Closing Limitless monitoring for router: %v.", + "LimitlessRouterMonitorImpl.hostSelectorStrategyNotFound": "Limitless Router Monitor unable to fetch host selector strategy and update host selector strategy weights.", + "LimitlessRouterMonitorImpl.noRoutersFetched": "Unable to discover any Limitless transaction routers.", + "LimitlessRouterMonitorImpl.openedConnection": "Opened Limitless Router Monitor connection to %v.", + "LimitlessRouterMonitorImpl.openingConnection": "Opening Limitless Router Monitor connection to %v.", + "LimitlessRouterMonitorImpl.startMonitoring": "Starting Limitless monitoring for router: %v.", + "LimitlessRouterMonitorImpl.stopMonitoring": "Stopped Limitless monitoring for router: %v.", + "LimitlessRouterMonitorImpl.unableToCastHostSelector": "Limitless Router Monitor unable to cast fetched host selector strategy and update host selector weights.", + "LimitlessRouterServiceImpl.errorStartingMonitor": "An error occurred while starting Limitless Router Monitor: '%v'.", + "LimitlessRouterServiceImpl.failedToConnectToRouter": "Failed to connect to Limitless router: %v.", + "LimitlessRouterServiceImpl.limitlessRouterCacheEmpty": "Limitless Router cache is empty. This is normal during application start up when the cache is not yet populated.", + "LimitlessRouterServiceImpl.maxConnectRetriesExceeded": "Max Limitless connection retries has been exceeded. Unable to connect to any transaction routers.", + "LimitlessRouterServiceImpl.noRoutersAvailable": "Unable to connect to any Limitless transaction routers", + "LimitlessRouterServiceImpl.selectedHost": "Limitless Router '%v' has been selected.", + "LimitlessRouterServiceImpl.unableToConnectNoRoutersAvailable": "Unable to connect to original host %v. All transaction routers are unavailable. Please verify connection credentials and network connectivity.", + "MonitorImpl.hostAlive": "Host %v is alive.", + "MonitorImpl.hostDead": "Host %v is dead.", + "MonitorImpl.hostNotResponding": "Host %v is not responding.", + "MonitorImpl.monitorIsStopped": "Monitoring was already stopped for host %v.", + "MonitorImpl.openedMonitoringConnection": "Opened monitoring connection to %v.", + "MonitorImpl.openingMonitoringConnection": "Opening monitoring connection to %v.", + "MonitorImpl.startMonitoringRoutine": "Start monitoring routine for %v.", + "MonitorImpl.startMonitoringRoutineNewState": "Start monitoring routine checking for new states for %v.", + "MonitorImpl.stopMonitoringRoutine": "Stop monitoring routine for %v.", + "MonitorImpl.stopMonitoringRoutineNewState": "Stop monitoring routine checking for new states for %v.", + "MonitorImpl.stopped": "Stopped monitoring routine for host %v.", + "MonitorImpl.updatingActiveStates": "Updating active states for %v. Going from %v state(s) to %v.", + "MonitorServiceImpl.errorAbortingConn": "An error was thrown when aborting monitoring Conn: %v.", + "MonitorServiceImpl.illegalArgumentError": "Illegal argument for '%v' was given to MonitorServiceImpl.", + "OktaAuthPlugin.failedSamlAssertion": "Could not get SAML response from the following endpoint: '%v'.", + "OktaAuthPlugin.httpNon200StatusCode": "Did not get a 200 response from http request to '%v'. Received the following status code: '%v'.", + "OktaAuthPlugin.unableToDetermineRegion": "Unable to determine the region. Please set the '%v' parameter.", + "OktaAuthPlugin.unableToRetrieveSessionToken": "Could not retrieve session token from endpoint.", + "PluginManager.pipelineNone": "A pipeline was requested but the created pipeline evaluated to nil.", + "PluginManager.unknownPluginCode": "Unknown plugin code: '%v'.", + "PluginManagerImpl.invokedAgainstOldConnection": "The internal connection has changed since %v was created, skip executing method %v. This is likely due to failover. To ensure you are using the updated connection, please re-create Statement, Tx, Result and Row objects after failover.", + "PluginManagerImpl.releaseResources": "Releasing resources from PluginManagerImpl.", + "PluginManagerImpl.unsupportedHostSelectionStrategy": "The wrapper does not support the requested host selection strategy: %v.", + "PluginServiceImpl.failedToRetrieveHostPort": "Could not retrieve Host:Port for connection.", + "PluginServiceImpl.hostListEmpty": "Host list is empty.", + "PluginServiceImpl.hostsChangelistEmpty": "There are no changes in the hosts' availability.", + "PluginServiceImpl.initialHostNotSet": "Unable to update dialect, initial HostInfo has not been set.", + "PluginServiceImpl.nilConn": "Unable to set current connection, given connection evaluates to nil.", + "PluginServiceImpl.nilHost": "Current host evaluates to nil.", + "PluginServiceImpl.nonEmptyAliases": "FillAliases called when HostInfo already contains the following aliases: '%v'.", + "PluginServiceImpl.releaseResources": "Releasing resources from PluginServiceImpl.", + "PluginServiceImpl.requiredBlockingHostListProvider": "The detected host list provider is not a BlockingHostListProvider. A BlockingHostListProvider is required to force refresh the host list. Detected host list provider: '%T'.", + "PluginServiceImpl.setCurrentHost": "Set current host to '%v'.", + "RdsHostListProvider.givenTemplateInvalid": "The given cluster format template is invalid, using default host formatting.", + "RdsHostListProvider.suggestedClusterId": "ClusterId '%v' is suggested for url '%v'.", + "RdsHostListProvider.unableToGatherTopology": "Unable to gather topology, no hosts identified from query, cache, or initial host list.", + "RdsHostListProvider.unableToGetHostName": "Unable to retrieve host name from given connection.", + "RdsHostListProvider.unknownHostRole": "Query to gather host role failed, unable to determine role of current host.", + "ReadWriteSplittingPlugin.couldNotRefreshHostlist": "Could not refresh host list", + "ReadWriteSplittingPlugin.emptyHostList": "Host list is empty", + "ReadWriteSplittingPlugin.errorSwitchingToCachedReader": "An error occurred while trying to switch to a cached reader connection: '%s'. The driver will attempt to establish a new reader connection.", + "ReadWriteSplittingPlugin.errorSwitchingToReader": "An error occured while trying to switch to a reader connection: '%v'", + "ReadWriteSplittingPlugin.errorSwitchingToWriter": "An error occured while trying to switch to the writer connection", + "ReadWriteSplittingPlugin.errorVerifyingInitialHostRole": "An error occurred while obtaining the connected host's role. This could occur if the client is broken or if you are not connected to an Aurora database.", + "ReadWriteSplittingPlugin.errorWhileExecutingCommand": "[ReadWriteSplitting] Detected an error while executing a command: '%s'", + "ReadWriteSplittingPlugin.failedToConnectToReader": "Failed to connect to reader host: '%s'", + "ReadWriteSplittingPlugin.failoverErrorWhileExecutingCommand": "Detected a failover error while executing a command: '%s'", + "ReadWriteSplittingPlugin.fallbackToWriter": "Failed to switch to reader '%v'. The current writer '%s' will be used as fallback", + "ReadWriteSplittingPlugin.noReadersAvailable": "The plugin was unable to establish a reader connection to any reader instance", + "ReadWriteSplittingPlugin.noWriterFound": "No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts", + "ReadWriteSplittingPlugin.setReadOnlyFalseInTransaction": "readOnly was set to true during a transaction. Please complete the transaction before setting readOnly to true", + "ReadWriteSplittingPlugin.setReadOnlyOnClosedConnection": "Cannot set ReadOnly on closed connection", + "ReadWriteSplittingPlugin.settingCurrentConnection": "Setting the current connection to '%s'", + "ReadWriteSplittingPlugin.switchedFromReaderToWriter": "Switched from a reader to a writer host. New writer host: '%s'", + "ReadWriteSplittingPlugin.switchedFromWriterToReader": "Switched from a writer to a reader host. New reader host: '%s'", + "ReadWriteSplittingPlugin.unsupportedHostSelectorStrategy": "Unsupported host selection strategy '%s', specified in plugin configuration parameter '%s'. Please visit the Read/Write Splitting documentation for all supported strategies", + "ReadWriteSplittingPlugin.updateInternalConnectionInfoFailed": "Cannot update internal connection. Error: '%s'", + "SamlCredentialsProviderFactory.getSamlAssertionFailed": "Failed to get SAML Assertion due to exception: '%s'.", + "SessionStateService.logState": "Current session state: \n%s", + "SessionStateService.transferIncomplete": "Previous session state transfer has not been completed.", + "SlidingExpirationCache.exitingCacheCleanupRoutine": "Sliding expiration cache '%v' cleanup routine has been cancelled.", + "SlidingExpirationCache.itemDisposal": "Disposing of %s.", + "SlidingExpirationCache.startingCacheCleanupRoutine": "Sliding expiration cache '%v', has been initialized, cleanup routine has started.", + "StaleDnsHelper.clusterEndpointDns": "Cluster endpoint resolves to '%s'.", + "StaleDnsHelper.staleDnsDetected": "Stale DNS data detected. Opening a connection to '%s'.", + "StaleDnsHelper.writerHostInfo": "Writer host: '%s'.", + "StaleDnsHelper.writerInetAddress": "Writer host address: '%s'.", + "TargetDriverHelper.invalidProtocol": "Invalid database protocol was found: %v", + "TargetDriverHelper.missingDriver": "Cannot find the target driver for %v. Please ensure the target driver is imported and registered. Here is the list of registered drivers found: %v", + "TelemetryContext.castError": "Could not cast provided TelemetryContext to type '%s'.", + "Utils.failedToReadHttpResponse": "Failed to read HTTP response body: '%s'.", + "Utils.rollbackError": "Error occurred when attempting rollback: '%s'.", + "Utils.topology": "%s \n%s", + "WeightedRandomHostSelector.noHostsMatchingRole": "Weighted Random strategy was unable to select a host. No available hosts were found matching the requested '%v' role.", + "WeightedRandomHostSelector.unableToGetHost": "Weighted Random strategy was unable to select a host." } diff --git a/awssql/utils/dsn_parser.go b/awssql/utils/dsn_parser.go index 5baa1e90..99a1eee3 100644 --- a/awssql/utils/dsn_parser.go +++ b/awssql/utils/dsn_parser.go @@ -51,7 +51,10 @@ func GetHostsFromDsn(dsn string, isSingleWriterDsn bool) (hostInfoList []*host_i if err != nil { return hostInfoList, err } + return GetHostsFromProps(properties, isSingleWriterDsn) +} +func GetHostsFromProps(properties map[string]string, isSingleWriterDsn bool) (hostInfoList []*host_info_util.HostInfo, err error) { hostStringList := strings.Split(properties[property_util.HOST.Name], ",") portStringList := strings.Split(properties[property_util.PORT.Name], ",") port := host_info_util.HOST_NO_PORT diff --git a/awssql/utils/pair.go b/awssql/utils/pair.go new file mode 100644 index 00000000..fffae068 --- /dev/null +++ b/awssql/utils/pair.go @@ -0,0 +1,40 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +type Pair[T any, U any] struct { + left T + right U +} + +func NewPair[T any, U any](t T, u U) Pair[T, U] { + return Pair[T, U]{t, u} +} + +func (p *Pair[T, U]) GetLeft() (t T) { + if p == nil { + return + } + return p.left +} + +func (p *Pair[T, U]) GetRight() (u U) { + if p == nil { + return + } + return p.right +} diff --git a/awssql/utils/rds_utils.go b/awssql/utils/rds_utils.go index aea705a3..f8665823 100644 --- a/awssql/utils/rds_utils.go +++ b/awssql/utils/rds_utils.go @@ -58,6 +58,10 @@ var ( IP_V6_COMPRESSED_REGEXP = regexp.MustCompile("^(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)" + "::(([0-9A-Fa-f]{1,4}(:[0-9A-Fa-f]{1,4}){0,5})?)$") + BG_OLD_HOST_PATTERN = regexp.MustCompile("(?i).*(?P-old1\\.)..*") //nolint:all + BG_GREEN_HOSTID_PATTERN = regexp.MustCompile("(?i)(.*)-green-[0-9a-z]{6}") + BG_GREEN_HOST_PATTERN = regexp.MustCompile("(?i).*(?P-green-[0-9a-z]{6})..*") + dnsRegexpArray = [4]*regexp.Regexp{AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN} cachedDnsRegexp = sync.Map{} cachedDnsPatterns = sync.Map{} @@ -96,6 +100,10 @@ func IdentifyRdsUrlType(host string) RdsUrlType { } } +func IsIP(host string) bool { + return IsIPv4(host) || IsIPV6(host) +} + func IsIPv4(host string) bool { return host != "" && IP_V4_REGEXP.MatchString(host) } @@ -153,6 +161,51 @@ func IsRdsDns(host string) bool { return true } +func IsRdsInstance(host string) bool { + preparedHost := GetPreparedHost(host) + + return getDnsGroup(preparedHost) == "" && IsRdsDns(preparedHost) +} + +func IsGreenInstance(host string) bool { + preparedHost := GetPreparedHost(host) + return preparedHost != "" && BG_GREEN_HOST_PATTERN.MatchString(preparedHost) +} +func IsNotOldInstance(host string) bool { + preparedHost := GetPreparedHost(host) + return preparedHost == "" || !BG_OLD_HOST_PATTERN.MatchString(preparedHost) +} + +// IsNotGreenAndNotOldInstance Verify host contains neither green prefix nor old prefix. +func IsNotGreenAndNotOldInstance(host string) bool { + preparedHost := GetPreparedHost(host) + return preparedHost != "" && !BG_GREEN_HOST_PATTERN.MatchString(preparedHost) && !BG_OLD_HOST_PATTERN.MatchString(preparedHost) +} + +func RemoveGreenInstancePrefix(host string) string { + preparedHost := GetPreparedHost(host) + if preparedHost == "" { + return host + } + // First try the main green host pattern to extract the prefix + if matches := BG_GREEN_HOST_PATTERN.FindStringSubmatch(preparedHost); matches != nil { + prefixIndex := BG_GREEN_HOST_PATTERN.SubexpIndex("prefix") + if prefixIndex >= 0 && prefixIndex < len(matches) { + prefix := matches[prefixIndex] + if prefix != "" { + return strings.Replace(host, prefix+".", ".", 1) + } + } + return host + } + // Fallback to the hostid pattern for cases where the main pattern doesn't match + if matches := BG_GREEN_HOSTID_PATTERN.FindStringSubmatch(preparedHost); len(matches) > 1 { + return matches[1] // Return the captured group (everything before -green-[hash]) + } + + return host +} + func getDnsGroup(host string) string { if host == "" { return "" @@ -197,6 +250,19 @@ func GetRdsRegion(host string) string { return cachedDnsRegexp.FindStringSubmatch(host)[cachedDnsRegexp.SubexpIndex(REGION_GROUP)] } +func GetRdsClusterId(host string) string { + preparedHost := GetPreparedHost(host) + if preparedHost == "" { + return "" + } + + cachedDnsRegexp, ok := findAndCacheRegexp(preparedHost) + if ok && cachedDnsRegexp.FindStringSubmatch(host)[cachedDnsRegexp.SubexpIndex(REGION_GROUP)] != "" { + return cachedDnsRegexp.FindStringSubmatch(host)[cachedDnsRegexp.SubexpIndex(INSTANCE_GROUP)] + } + return "" +} + func findAndCacheRegexp(host string) (regexp.Regexp, bool) { val, ok := cachedDnsRegexp.Load(host) if ok && val != nil { diff --git a/awssql/utils/rw_map.go b/awssql/utils/rw_map.go new file mode 100644 index 00000000..c0ebaa85 --- /dev/null +++ b/awssql/utils/rw_map.go @@ -0,0 +1,151 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +import ( + "sync" +) + +type RWMap[T any] struct { + cache map[string]T + disposalFunc DisposalFunc[T] + lock sync.RWMutex +} + +func NewRWMap[T any]() *RWMap[T] { + return &RWMap[T]{ + cache: make(map[string]T), + } +} + +func NewRWMapWithDisposalFunc[T any](disposalFunc DisposalFunc[T]) *RWMap[T] { + return &RWMap[T]{ + cache: make(map[string]T), + disposalFunc: disposalFunc, + } +} + +func (c *RWMap[T]) Put(key string, value T) { + c.lock.Lock() + defer c.lock.Unlock() + val, ok := c.cache[key] + if ok && c.disposalFunc != nil { + c.disposalFunc(val) + } + c.cache[key] = value +} + +func (c *RWMap[T]) Get(key string) (T, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + val, ok := c.cache[key] + return val, ok +} + +func (c *RWMap[T]) ComputeIfAbsent(key string, computeFunc func() T) T { + c.lock.Lock() + defer c.lock.Unlock() + + item, ok := c.cache[key] + if ok { + return item + } + + c.cache[key] = computeFunc() + return c.cache[key] +} + +func (c *RWMap[T]) PutIfAbsent(key string, value T) { + c.lock.Lock() + defer c.lock.Unlock() + _, ok := c.cache[key] + + if !ok { + c.cache[key] = value + } +} + +func (c *RWMap[T]) Remove(key string) { + c.lock.Lock() + val, ok := c.cache[key] + if ok { + if c.disposalFunc != nil { + c.disposalFunc(val) + } + delete(c.cache, key) + } + c.lock.Unlock() +} + +func (c *RWMap[T]) Clear() { + if c.disposalFunc != nil { + c.clearWithDisposalFunc() + return + } + + c.lock.Lock() + defer c.lock.Unlock() + for key := range c.cache { + delete(c.cache, key) + } +} + +func (c *RWMap[T]) clearWithDisposalFunc() { + c.lock.Lock() + defer c.lock.Unlock() + + for key, value := range c.cache { + c.disposalFunc(value) + delete(c.cache, key) + } +} + +// Returns a map copy of all entries in the cache. +func (c *RWMap[T]) GetAllEntries() map[string]T { + c.lock.RLock() + defer c.lock.RUnlock() + + entryMap := make(map[string]T) + for key, value := range c.cache { + entryMap[key] = value + } + return entryMap +} + +func (c *RWMap[T]) ReplaceCacheWithCopy(mapToCopy *RWMap[T]) { + entryMap := mapToCopy.GetAllEntries() + + c.lock.Lock() + defer c.lock.Unlock() + if c.disposalFunc != nil { + for _, value := range c.cache { + c.disposalFunc(value) + } + } + c.cache = entryMap +} + +func (c *RWMap[T]) Size() int { + if c == nil { + return 0 + } + c.lock.RLock() + defer c.lock.RUnlock() + + return len(c.cache) +} diff --git a/awssql/utils/utils.go b/awssql/utils/utils.go index 93d29297..23f45e29 100644 --- a/awssql/utils/utils.go +++ b/awssql/utils/utils.go @@ -22,7 +22,6 @@ import ( "database/sql/driver" "fmt" "log/slog" - "slices" "strconv" "strings" "sync" @@ -59,6 +58,7 @@ func FindHostInTopology(hosts []*host_info_util.HostInfo, hostNames ...string) * return nil } +// ExecQueryDirectly Directly executes query on conn. func ExecQueryDirectly(conn driver.Conn, query string) error { execerCtx, ok := conn.(driver.ExecerContext) if !ok { @@ -74,9 +74,8 @@ func ExecQueryDirectly(conn driver.Conn, query string) error { return nil } -// Directly executes query on conn, and returns the first row. -// Returns nil if unable to obtain a row. -func GetFirstRowFromQuery(conn driver.Conn, query string) []driver.Value { +// GetRowsFromQuery Directly executes query on conn, and returns the first n rows. +func GetRowsFromQuery(conn driver.Conn, query string, n int) [][]driver.Value { queryerCtx, ok := conn.(driver.QueryerContext) if !ok { // Unable to query, conn does not implement QueryerContext. @@ -84,24 +83,36 @@ func GetFirstRowFromQuery(conn driver.Conn, query string) []driver.Value { } rows, err := queryerCtx.QueryContext(context.Background(), query, nil) - if err != nil { + if err != nil || rows == nil { // Query failed. return nil } - if rows != nil { - defer rows.Close() + defer func(rows driver.Rows) { + _ = rows.Close() + }(rows) + + res := make([][]driver.Value, n) + row := make([]driver.Value, len(rows.Columns())) + for i := 0; i < n; i++ { + err = rows.Next(row) + if err != nil { + // Gathering row failed. + break + } + res[i] = row } + return res +} - res := make([]driver.Value, len(rows.Columns())) - err = rows.Next(res) - if err != nil { - // Gathering row failed. +func GetFirstRowFromQuery(conn driver.Conn, query string) []driver.Value { + res := GetRowsFromQuery(conn, query, 1) + if len(res) < 1 { return nil } - return res + return res[0] } -// Directly executes query on conn and converts all possible values in the first row to strings. +// GetFirstRowFromQueryAsString Directly executes query on conn and converts all possible values in the first row to strings. // Any values that cannot be converted are returned as "". Returns nil if unable to obtain a row. func GetFirstRowFromQueryAsString(conn driver.Conn, query string) []string { row := GetFirstRowFromQuery(conn, query) @@ -144,9 +155,39 @@ func FilterSlice[T any](slice []T, filter func(T) bool) []T { return result } -func SliceAndMapHaveCommonElement[T comparable, V any](sliceA []T, mapOfKeysAndValues map[T]V) bool { - for item := range mapOfKeysAndValues { - if slices.Contains(sliceA, item) { +func FilterSliceFindFirst[T any](slice []T, filter func(T) bool) T { + var zero T + for _, v := range slice { + if filter(v) { + return v + } + } + return zero +} + +func FilterSetFindFirst[T comparable, U any](set map[T]U, filter func(T) bool) T { + var zero T + for val := range set { + if filter(val) { + return val + } + } + return zero +} + +func FilterMapFindFirstValue[T comparable, U any](set map[T]U, filter func(U) bool) U { + var zero U + for _, val := range set { + if filter(val) { + return val + } + } + return zero +} + +func SliceAndMapHaveCommonElement[T comparable, V any](slice []T, mapOfKeysAndValues map[T]V) bool { + for _, item := range slice { + if _, exists := mapOfKeysAndValues[item]; exists { return true } } @@ -293,3 +334,13 @@ func GetSetReadOnlyFromCtx(ctx context.Context) bool { } return setReadOnly } + +func MySqlConvertValToString(value driver.Value) (string, bool) { + stringAsInt, ok := value.([]uint8) + return string(stringAsInt), ok +} + +func PgConvertValToString(value driver.Value) (string, bool) { + stringVal, ok := value.(string) + return stringVal, ok +} diff --git a/docs/contributor-guide/ContributorGuide.md b/docs/contributor-guide/ContributorGuide.md index f212d146..ff9c409d 100644 --- a/docs/contributor-guide/ContributorGuide.md +++ b/docs/contributor-guide/ContributorGuide.md @@ -42,7 +42,7 @@ There are specific benchmarks measuring the AWS Advanced Go Wrapper's plugin pip ![](../images/go_wrapper_execute_pipelines_benchmarks.png) -##### [Release Resources Pipeline](../contributor-guide/Pipelines.md#release-resources-pipeline) +##### Release Resources Pipeline ![](../images/go_wrapper_releaseresources_pipelines_benchmarks.png) @@ -65,7 +65,7 @@ The following diagrams show how the AWS Advanced Go Wrapper performs under a mor Common Failure Detection Setting | Parameter | Value | -| -------------------------- | -------- | +|----------------------------|----------| | `failoverTimeoutMs` | `120000` | | `failureDetectionTime` | `30000` | | `failureDetectionInterval` | `5000` | @@ -74,7 +74,7 @@ Common Failure Detection Setting Aggressive Failure Detection Setting | Parameter | Value | -| -------------------------- | -------- | +|----------------------------|----------| | `failoverTimeoutMs` | `120000` | | `failureDetectionTime` | `6000` | | `failureDetectionInterval` | `1000` | diff --git a/docs/contributor-guide/LoadablePlugins.md b/docs/contributor-guide/LoadablePlugins.md index dd812c2d..1893d239 100644 --- a/docs/contributor-guide/LoadablePlugins.md +++ b/docs/contributor-guide/LoadablePlugins.md @@ -68,12 +68,12 @@ Conn, Result, Rows, Stmt, and Tx; some examples are as follows: Plugins can also subscribe to the following pipelines: -| Pipeline | Method Name / Subscription Key | -|-----------------------------------------------------------------------------------------------------|:------------------------------:| -| [Host list provider pipeline](./Pipelines.md#host-list-provider-pipeline) | initHostProvider | -| [Connect pipeline](./Pipelines.md#connect-pipeline) | Conn.Connect | -| [Connection changed notification pipeline](./Pipelines.md#connection-changed-notification-pipeline) | notifyConnectionChanged | -| [Host list changed notification pipeline](./Pipelines.md#host-list-changed-notification-pipeline) | notifyHostListChanged | +| Pipeline | Method Name / Subscription Key | +|---------------------------------------------------------------------------------------------|:------------------------------:| +| [Host list provider pipeline](./Pipelines.md#hostlistprovider-pipeline) | initHostProvider | +| [Connect pipeline](./Pipelines.md#connect-pipeline) | Conn.Connect | +| [Connection changed notification pipeline](./Pipelines.md#notifyconnectionchanged-pipeline) | notifyConnectionChanged | +| [Host list changed notification pipeline](./Pipelines.md#notifyhostlistchanged-pipeline) | notifyHostListChanged | ### Tips on Creating a Custom Plugin diff --git a/federated-auth/federated_auth_plugin.go b/federated-auth/federated_auth_plugin.go index d31e8889..621357b2 100644 --- a/federated-auth/federated_auth_plugin.go +++ b/federated-auth/federated_auth_plugin.go @@ -35,7 +35,8 @@ import ( ) func init() { - awssql.UsePluginFactory("federatedAuth", NewFederatedAuthPluginFactory()) + awssql.UsePluginFactory(driver_infrastructure.ADFS_PLUGIN_CODE, + NewFederatedAuthPluginFactory()) } var TokenCache = utils.NewCache[string]() @@ -79,6 +80,10 @@ func NewFederatedAuthPlugin( }, nil } +func (f *FederatedAuthPlugin) GetPluginCode() string { + return driver_infrastructure.ADFS_PLUGIN_CODE +} + func (f *FederatedAuthPlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.CONNECT_METHOD, plugin_helpers.FORCE_CONNECT_METHOD} } diff --git a/iam/iam_auth_plugin.go b/iam/iam_auth_plugin.go index 3822bca9..7484d820 100644 --- a/iam/iam_auth_plugin.go +++ b/iam/iam_auth_plugin.go @@ -36,7 +36,8 @@ import ( ) func init() { - awssql.UsePluginFactory("iam", NewIamAuthPluginFactory()) + awssql.UsePluginFactory(driver_infrastructure.IAM_PLUGIN_CODE, + NewIamAuthPluginFactory()) } type IamAuthPluginFactory struct{} @@ -63,6 +64,10 @@ type IamAuthPlugin struct { fetchTokenCounter telemetry.TelemetryCounter } +func (iamAuthPlugin *IamAuthPlugin) GetPluginCode() string { + return driver_infrastructure.IAM_PLUGIN_CODE +} + func NewIamAuthPlugin(pluginService driver_infrastructure.PluginService, iamTokenUtility auth_helpers.IamTokenUtility, props map[string]string) (*IamAuthPlugin, error) { fetchTokenCounter, err := pluginService.GetTelemetryFactory().CreateCounter("iam.fetchToken.count") if err != nil { diff --git a/okta/okta_auth_plugin.go b/okta/okta_auth_plugin.go index 8de5c596..b35977e0 100644 --- a/okta/okta_auth_plugin.go +++ b/okta/okta_auth_plugin.go @@ -35,7 +35,8 @@ import ( ) func init() { - awssql.UsePluginFactory("okta", NewOktaAuthPluginFactory()) + awssql.UsePluginFactory(driver_infrastructure.OKTA_PLUGIN_CODE, + NewOktaAuthPluginFactory()) } type OktaAuthPluginFactory struct{} @@ -83,6 +84,10 @@ func NewOktaAuthPlugin( }, nil } +func (o *OktaAuthPlugin) GetPluginCode() string { + return driver_infrastructure.OKTA_PLUGIN_CODE +} + func (o *OktaAuthPlugin) GetSubscribedMethods() []string { return []string{plugin_helpers.CONNECT_METHOD, plugin_helpers.FORCE_CONNECT_METHOD} }