This repository has been archived by the owner on Feb 6, 2024. It is now read-only.

fix: mod zero #171

Merged · 2 commits · May 18, 2023
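For context: the "mod zero" in the title most likely refers to the `len(result)%totalShardLength` check in the old `PickShards` loop (see the server/coordinator/shard_picker.go diff below). `totalShardLength` is `len(shardNodes)`, so on a snapshot with no shard nodes the modulo is taken against zero and panics. A minimal sketch of that failure mode, not ceresmeta code:

```go
package main

import "fmt"

func main() {
	// Analogous to the old picker: totalShardLength := len(shardNodes).
	var shardNodes []string // an empty cluster snapshot
	totalShardLength := len(shardNodes)

	picked := 0
	// Integer modulo by a zero variable compiles but panics at runtime
	// with "integer divide by zero".
	fmt.Println(picked % totalShardLength)
}
```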
4 changes: 2 additions & 2 deletions server/coordinator/factory.go
@@ -113,7 +113,7 @@ func (f *Factory) makeCreateTableProcedure(ctx context.Context, request CreateTa
}
snapshot := request.ClusterMetadata.GetClusterSnapshot()

shards, err := f.shardPicker.PickShards(ctx, snapshot, 1, false)
shards, err := f.shardPicker.PickShards(ctx, snapshot, 1)
if err != nil {
log.Error("pick table shard", zap.Error(err))
return nil, errors.WithMessage(err, "pick table shard")
@@ -148,7 +148,7 @@ func (f *Factory) makeCreatePartitionTableProcedure(ctx context.Context, request
nodeNames[shardNode.NodeName] = 1
}

subTableShards, err := f.shardPicker.PickShards(ctx, snapshot, len(request.SourceReq.PartitionTableInfo.SubTableNames), true)
subTableShards, err := f.shardPicker.PickShards(ctx, snapshot, len(request.SourceReq.PartitionTableInfo.SubTableNames))
if err != nil {
return nil, errors.WithMessage(err, "pick sub table shards")
}
@@ -36,7 +36,7 @@ func TestCreatePartitionTable(t *testing.T) {
}

shardPicker := coordinator.NewRandomBalancedShardPicker()
subTableShards, err := shardPicker.PickShards(ctx, c.GetMetadata().GetClusterSnapshot(), len(request.GetPartitionTableInfo().SubTableNames), true)
subTableShards, err := shardPicker.PickShards(ctx, c.GetMetadata().GetClusterSnapshot(), len(request.GetPartitionTableInfo().SubTableNames))

shardNodesWithVersion := make([]metadata.ShardNodeWithVersion, 0, len(subTableShards))
for _, subTableShard := range subTableShards {
@@ -87,7 +87,7 @@ func testCreatePartitionTable(ctx context.Context, t *testing.T, dispatch eventd
Name: tableName,
}

subTableShards, err := shardPicker.PickShards(ctx, c.GetMetadata().GetClusterSnapshot(), len(request.GetPartitionTableInfo().SubTableNames), true)
subTableShards, err := shardPicker.PickShards(ctx, c.GetMetadata().GetClusterSnapshot(), len(request.GetPartitionTableInfo().SubTableNames))
re.NoError(err)

shardNodesWithVersion := make([]metadata.ShardNodeWithVersion, 0, len(subTableShards))
16 changes: 9 additions & 7 deletions server/coordinator/procedure/test/common.go
@@ -4,7 +4,9 @@ package test

import (
"context"
"crypto/rand"
"fmt"
"math/big"
"testing"

"github.com/CeresDB/ceresmeta/server/cluster"
@@ -123,13 +125,13 @@ func InitStableCluster(ctx context.Context, t *testing.T) *cluster.Cluster {
snapshot := c.GetMetadata().GetClusterSnapshot()
shardNodes := make([]storage.ShardNode, 0, DefaultShardTotal)
for _, shardView := range snapshot.Topology.ShardViewsMapping {
for _, node := range snapshot.RegisteredNodes {
shardNodes = append(shardNodes, storage.ShardNode{
ID: shardView.ShardID,
ShardRole: storage.ShardRoleLeader,
NodeName: node.Node.Name,
})
}
selectNodeIdx, err := rand.Int(rand.Reader, big.NewInt(int64(len(snapshot.RegisteredNodes))))
re.NoError(err)
shardNodes = append(shardNodes, storage.ShardNode{
ID: shardView.ShardID,
ShardRole: storage.ShardRoleLeader,
NodeName: snapshot.RegisteredNodes[selectNodeIdx.Int64()].Node.Name,
})
}

err := c.GetMetadata().UpdateClusterView(ctx, storage.ClusterStateStable, shardNodes)
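The test helper now assigns each shard view to one randomly chosen registered node instead of appending an entry for every node. The index is drawn with `crypto/rand.Int`, which returns a uniform `*big.Int` in `[0, max)`; the picker itself uses the same pattern. A small standalone sketch of that pattern (illustrative only, not part of the diff):

```go
package main

import (
	"crypto/rand"
	"fmt"
	"math/big"
)

// pickRandomIndex returns a uniformly random index in [0, n).
// n must be positive, otherwise rand.Int panics.
func pickRandomIndex(n int) (int, error) {
	idx, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
	if err != nil {
		return 0, err
	}
	return int(idx.Int64()), nil
}

func main() {
	nodes := []string{"node-0", "node-1", "node-2"}
	i, err := pickRandomIndex(len(nodes))
	if err != nil {
		panic(err)
	}
	fmt.Println("picked:", nodes[i])
}
```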
90 changes: 46 additions & 44 deletions server/coordinator/shard_picker.go
@@ -14,11 +14,9 @@ import (

// ShardPicker is used to pick up the shards suitable for scheduling in the cluster.
// If expectShardNum bigger than cluster node number, the result depends on enableDuplicateNode:
// If enableDuplicateNode is false, pick shards will be failed and return error.
// If enableDuplicateNode is true, pick shard will return shards on the same node.
// TODO: Consider refactor this interface, abstracts the parameters of PickShards as PickStrategy.
type ShardPicker interface {
PickShards(ctx context.Context, snapshot metadata.Snapshot, expectShardNum int, enableDuplicateNode bool) ([]storage.ShardNode, error)
PickShards(ctx context.Context, snapshot metadata.Snapshot, expectShardNum int) ([]storage.ShardNode, error)
}

// RandomBalancedShardPicker randomly pick up shards that are not on the same node in the current cluster.
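With `enableDuplicateNode` gone, callers only state how many shards they need; the picker spreads picks across nodes and reuses a node only after every node has been used in the current round. A rough caller-side sketch, assuming `ctx`, `snapshot`, and `subTableNames` are already in scope (error handling abbreviated):

```go
picker := coordinator.NewRandomBalancedShardPicker()

// One shard per sub-table; asking for more picks than there are nodes
// (or even shards) is fine now, the picker simply wraps around.
shardNodes, err := picker.PickShards(ctx, snapshot, len(subTableNames))
if err != nil {
	return err
}
for _, sn := range shardNodes {
	fmt.Printf("shard %d -> node %s\n", sn.ID, sn.NodeName)
}
```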
@@ -29,67 +27,71 @@ func NewRandomBalancedShardPicker() ShardPicker {
}

// PickShards will pick a specified number of shards as expectShardNum.
func (p *RandomBalancedShardPicker) PickShards(_ context.Context, snapshot metadata.Snapshot, expectShardNum int, enableDuplicateNode bool) ([]storage.ShardNode, error) {
func (p *RandomBalancedShardPicker) PickShards(_ context.Context, snapshot metadata.Snapshot, expectShardNum int) ([]storage.ShardNode, error) {
shardNodes := snapshot.Topology.ClusterView.ShardNodes

nodeShardsMapping := make(map[string][]storage.ShardNode, 0)
for _, shardNode := range shardNodes {
_, exists := nodeShardsMapping[shardNode.NodeName]
if !exists {
shardNodes := []storage.ShardNode{}
nodeShardsMapping[shardNode.NodeName] = shardNodes
nodeShardsMapping[shardNode.NodeName] = []storage.ShardNode{}
}
nodeShardsMapping[shardNode.NodeName] = append(nodeShardsMapping[shardNode.NodeName], shardNode)
}

if !enableDuplicateNode {
if len(nodeShardsMapping) < expectShardNum {
return nil, errors.WithMessagef(ErrNodeNumberNotEnough, "number of nodes is:%d, expecet number of shards is:%d", len(nodeShardsMapping), expectShardNum)
}
}

// Try to make shards on different nodes.
result := make([]storage.ShardNode, 0, expectShardNum)
totalShardLength := len(shardNodes)
tempNodeShardMapping := make(map[string][]storage.ShardNode, len(nodeShardsMapping))
for {
nodeNames := make([]string, 0, len(nodeShardsMapping))
for nodeName := range nodeShardsMapping {
nodeNames = append(nodeNames, nodeName)
nodeNames := make([]string, 0, len(nodeShardsMapping))
tempNodeShardMapping := copyNodeShardMapping(nodeShardsMapping)

for i := 0; i < expectShardNum; i++ {
// Initialize nodeNames.
if len(nodeNames) == 0 {
for nodeName := range nodeShardsMapping {
nodeNames = append(nodeNames, nodeName)
}
}

// Reset node shards when shard is all picked.
if len(result)%totalShardLength == 0 {
for nodeName, shardNode := range nodeShardsMapping {
tempShardNode := make([]storage.ShardNode, len(shardNode))
copy(tempShardNode, shardNode)
tempNodeShardMapping[nodeName] = tempShardNode
}
// Get random node.
selectNodeIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nodeNames))))
if err != nil {
return nil, errors.WithMessage(err, "generate random node index")
}
nodeShards := tempNodeShardMapping[nodeNames[selectNodeIndex.Int64()]]

for len(nodeNames) > 0 {
if len(result) >= expectShardNum {
return result, nil
}
// When node shards is empty, copy from nodeShardsMapping and get shards again.
if len(nodeShards) == 0 {
tempNodeShardMapping = copyNodeShardMapping(nodeShardsMapping)

selectNodeIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nodeNames))))
if err != nil {
return nil, errors.WithMessage(err, "generate random node index")
}
nodeShards = tempNodeShardMapping[nodeNames[selectNodeIndex.Int64()]]
}

nodeShards := tempNodeShardMapping[nodeNames[selectNodeIndex.Int64()]]
// Get random shard.
selectNodeShardsIndex, err := rand.Int(rand.Reader, big.NewInt(int64(len(nodeShards))))
if err != nil {
return nil, errors.WithMessage(err, "generate random node shard index")
}

if len(nodeShards) > 0 {
result = append(result, nodeShards[0])
result = append(result, nodeShards[selectNodeShardsIndex.Int64()])

// Remove select shard.
nodeShards[0] = nodeShards[len(nodeShards)-1]
tempNodeShardMapping[nodeNames[selectNodeIndex.Int64()]] = nodeShards[:len(nodeShards)-1]
}
// Remove select shard.
nodeShards[selectNodeShardsIndex.Int64()] = nodeShards[len(nodeShards)-1]
tempNodeShardMapping[nodeNames[selectNodeIndex.Int64()]] = nodeShards[:len(nodeShards)-1]

// Remove select node.
nodeNames[selectNodeIndex.Int64()] = nodeNames[len(nodeNames)-1]
nodeNames = nodeNames[:len(nodeNames)-1]
}
// Remove select node.
nodeNames[selectNodeIndex.Int64()] = nodeNames[len(nodeNames)-1]
nodeNames = nodeNames[:len(nodeNames)-1]
}

return result, nil
}

func copyNodeShardMapping(nodeShardsMapping map[string][]storage.ShardNode) map[string][]storage.ShardNode {
tempNodeShardMapping := make(map[string][]storage.ShardNode, len(nodeShardsMapping))
for nodeName, shardNode := range nodeShardsMapping {
tempShardNode := make([]storage.ShardNode, len(shardNode))
copy(tempShardNode, shardNode)
tempNodeShardMapping[nodeName] = tempShardNode
}
return tempNodeShardMapping
}
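The refill step above only works because `copyNodeShardMapping` is a deep copy: copying a `map[string][]storage.ShardNode` key by key would leave the slice values sharing backing arrays with the original, and the swap-and-truncate removal would then corrupt `nodeShardsMapping` for the next refill. The same pattern in a generic, self-contained form (illustrative types):

```go
// deepCopySliceMap copies a map of slices so that mutating the copy
// (for example the swap-and-truncate removal used by the picker) cannot
// affect the original map's slices.
func deepCopySliceMap(src map[string][]int) map[string][]int {
	dst := make(map[string][]int, len(src))
	for k, v := range src {
		tmp := make([]int, len(v))
		copy(tmp, v)
		dst[k] = tmp
	}
	return dst
}
```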
22 changes: 8 additions & 14 deletions server/coordinator/shard_picker_test.go
@@ -19,25 +19,19 @@ func TestRandomShardPicker(t *testing.T) {
snapshot := c.GetMetadata().GetClusterSnapshot()

shardPicker := coordinator.NewRandomBalancedShardPicker()
shardNodes, err := shardPicker.PickShards(ctx, snapshot, 2, false)
re.NoError(err)

// Verify the number of shards and ensure that they are not on the same node.
re.Equal(len(shardNodes), 2)
re.NotEqual(shardNodes[0].NodeName, shardNodes[1].NodeName)

// ExpectShardNum is bigger than node number and enableDuplicateNode is false, it should be throw error.
_, err = shardPicker.PickShards(ctx, snapshot, 3, false)
re.Error(err)

// ExpectShardNum is bigger than node number and enableDuplicateNode is true, it should return correct shards.
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 3, true)
shardNodes, err := shardPicker.PickShards(ctx, snapshot, 3)
re.NoError(err)
re.Equal(len(shardNodes), 3)
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 4, true)
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 4)
re.NoError(err)
re.Equal(len(shardNodes), 4)
// ExpectShardNum is bigger than shard number.
_, err = shardPicker.PickShards(ctx, snapshot, 5, true)
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 5)
re.NoError(err)
re.Equal(len(shardNodes), 5)
// TODO: Ensure that the shardNodes is average distributed on nodes and shards.
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 9)
re.NoError(err)
re.Equal(len(shardNodes), 9)
}
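The remaining TODO asks for a distribution check. One possible shape for it, reusing the test's `re`, `ctx`, `snapshot`, and `shardPicker` (a sketch only, not part of the PR; it assumes every node keeps at least one shard in the snapshot, in which case round-robin picking keeps per-node counts within one of each other):

```go
shardNodes, err = shardPicker.PickShards(ctx, snapshot, 9)
re.NoError(err)

// Count how many picks landed on each node.
perNode := make(map[string]int)
for _, sn := range shardNodes {
	perNode[sn.NodeName]++
}

minCount, maxCount := len(shardNodes), 0
for _, n := range perNode {
	if n < minCount {
		minCount = n
	}
	if n > maxCount {
		maxCount = n
	}
}
// Every node should be picked either k or k+1 times.
re.LessOrEqual(maxCount-minCount, 1)
```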