diff --git a/server/api/shard_test.go b/server/api/shard_test.go index 39d9f91a..95b423c1 100644 --- a/server/api/shard_test.go +++ b/server/api/shard_test.go @@ -28,8 +28,12 @@ import ( "net/http/httptest" "strconv" "testing" + "time" + "github.com/apache/kvrocks-controller/config" + "github.com/apache/kvrocks-controller/controller" "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" "github.com/stretchr/testify/require" "github.com/apache/kvrocks-controller/consts" @@ -161,12 +165,22 @@ func TestShardBasics(t *testing.T) { func TestClusterFailover(t *testing.T) { ns := "test-ns" clusterName := "test-cluster-failover" - handler := &ShardHandler{s: store.NewClusterStore(engine.NewMock())} + clusterStore := store.NewClusterStore(engine.NewMock()) + handler := &ShardHandler{s: clusterStore} cluster, err := store.NewCluster(clusterName, []string{"127.0.0.1:7770", "127.0.0.1:7771"}, 2) require.NoError(t, err) node0, _ := cluster.Shards[0].Nodes[0].(*store.ClusterNode) node1, _ := cluster.Shards[0].Nodes[1].(*store.ClusterNode) + ctx := context.Background() + ctrl, err := controller.New(clusterStore, &config.ControllerConfig{ + FailOver: &config.FailOverConfig{MaxPingCount: 3, PingIntervalSeconds: 3}, + }) + require.NoError(t, err) + require.NoError(t, ctrl.Start(ctx)) + ctrl.WaitForReady() + defer ctrl.Close() + runFailover := func(t *testing.T, shardIndex, expectedStatusCode int) { recorder := httptest.NewRecorder() ctx := GetTestContext(recorder) @@ -191,6 +205,26 @@ func TestClusterFailover(t *testing.T) { }() require.NoError(t, handler.s.CreateCluster(ctx, ns, cluster)) + require.Eventually(t, func() bool { + // Confirm that the cluster info has been synced to each node + clusterInfo, err := node1.GetClusterInfo(ctx) + if err != nil { + return false + } + return clusterInfo.CurrentEpoch >= 1 + }, 10*time.Second, 100*time.Millisecond) + masterClient := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{node0.Addr()}, + }) + require.NoError(t, masterClient.Set(ctx, "a", 100, 0).Err()) + require.Eventually(t, func() bool { + slaveClient := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{node1.Addr()}, + ReadOnly: true, + }) + return slaveClient.Get(ctx, "a").Val() == "100" + }, 10*time.Second, 100*time.Millisecond) + runFailover(t, 0, http.StatusOK) }) diff --git a/store/cluster_mock_node.go b/store/cluster_mock_node.go index 35988ae5..1f72ae0f 100644 --- a/store/cluster_mock_node.go +++ b/store/cluster_mock_node.go @@ -39,7 +39,7 @@ func NewClusterMockNode() *ClusterMockNode { } func (mock *ClusterMockNode) GetClusterNodeInfo(ctx context.Context) (*ClusterNodeInfo, error) { - return &ClusterNodeInfo{Sequence: mock.Sequence}, nil + return &ClusterNodeInfo{Sequence: mock.Sequence, Role: mock.role}, nil } func (mock *ClusterMockNode) GetClusterInfo(ctx context.Context) (*ClusterInfo, error) {