diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 12b29f2f3076..e58acb29be84 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -555,14 +555,15 @@ func prepare4Join(testCase *hashJoinTestCase, innerExec, outerExec Executor) *Ha joinKeys = append(joinKeys, cols0[keyIdx]) } e := &HashJoinExec{ - baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), - concurrency: uint(testCase.concurrency), - joinType: 0, // InnerJoin - isOuterJoin: false, - innerKeys: joinKeys, - outerKeys: joinKeys, - innerExec: innerExec, - outerExec: outerExec, + baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), + concurrency: uint(testCase.concurrency), + joinType: 0, // InnerJoin + isOuterJoin: false, + innerKeys: joinKeys, + outerKeys: joinKeys, + innerExec: innerExec, + outerExec: outerExec, + innerStatsCount: float64(testCase.rows), } defaultValues := make([]types.Datum, e.innerExec.Schema().Len()) lhsTypes, rhsTypes := retTypes(innerExec), retTypes(outerExec) diff --git a/executor/builder.go b/executor/builder.go index 210245065e69..c57a1e7c4f62 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -991,10 +991,11 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } e := &HashJoinExec{ - baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), leftExec, rightExec), - concurrency: v.Concurrency, - joinType: v.JoinType, - isOuterJoin: v.JoinType.IsOuterJoin(), + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), leftExec, rightExec), + concurrency: v.Concurrency, + joinType: v.JoinType, + isOuterJoin: v.JoinType.IsOuterJoin(), + innerStatsCount: v.Children()[v.InnerChildIdx].StatsCount(), } defaultValues := v.DefaultValues diff --git a/executor/hash_table.go b/executor/hash_table.go index d2bbc7dd80da..22ca061b7933 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -38,11 +38,11 @@ type hashRowContainer struct { } func newHashRowContainer( - sc *stmtctx.StatementContext, + sc *stmtctx.StatementContext, statCount int, allTypes []*types.FieldType, keyColIdx []int, initCap, maxChunkSize int) *hashRowContainer { c := &hashRowContainer{ - hashTable: newRowHashMap(), + hashTable: newRowHashMapWithStatCount(statCount), sc: sc, allTypes: allTypes, keyColIdx: keyColIdx, @@ -207,14 +207,17 @@ type rowHashMap struct { } // newRowHashMap creates a new rowHashMap. -func newRowHashMap() *rowHashMap { +func newRowHashMapWithStatCount(statCount int) *rowHashMap { m := new(rowHashMap) - // TODO(fengliyuan): initialize the size of map from the estimated row count for better performance. - m.hashTable = make(map[uint64]entryAddr) + m.hashTable = make(map[uint64]entryAddr, statCount) m.entryStore.init() return m } +func newRowHashMap() *rowHashMap { + return newRowHashMapWithStatCount(0) +} + // Put puts the key/rowPtr pairs to the rowHashMap, multiple rowPtrs are stored in a list. func (m *rowHashMap) Put(hashKey uint64, rowPtr chunk.RowPtr) { oldEntryAddr := m.hashTable[hashKey] diff --git a/executor/join.go b/executor/join.go index a6c61f1abf43..e414f780ab4a 100644 --- a/executor/join.go +++ b/executor/join.go @@ -40,11 +40,12 @@ var ( type HashJoinExec struct { baseExecutor - outerExec Executor - innerExec Executor - outerFilter expression.CNFExprs - outerKeys []*expression.Column - innerKeys []*expression.Column + outerExec Executor + innerExec Executor + innerStatsCount float64 + outerFilter expression.CNFExprs + outerKeys []*expression.Column + innerKeys []*expression.Column // concurrency is the number of partition, build and join workers. concurrency uint @@ -492,13 +493,28 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { } } +const ( + // statCountMaxFactor defines the factor of maxStatCount with maxChunkSize. + // statCountMax is maxChunkSize * maxStatCountFactor. + // Set this threshold to prevent innerStatsCount being too large and causing a performance regression. + statCountMaxFactor = 10 * 1024 + + // statCountDivisor defines the divisor of innerStatsCount. + // Set this divisor to prevent innerStatsCount being too large and causing a performance regression. + statCountDivisor = 8 +) + // buildHashTableForList builds hash table from `list`. func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) error { innerKeyColIdx := make([]int, len(e.innerKeys)) for i := range e.innerKeys { innerKeyColIdx[i] = e.innerKeys[i].Index } - e.rowContainer = newHashRowContainer(e.ctx.GetSessionVars().StmtCtx, + statCount := int(e.innerStatsCount / statCountDivisor) + if statCount > e.maxChunkSize*statCountMaxFactor { + statCount = e.maxChunkSize * statCountMaxFactor + } + e.rowContainer = newHashRowContainer(e.ctx.GetSessionVars().StmtCtx, statCount, e.innerExec.base().retFieldTypes, innerKeyColIdx, e.initCap, e.maxChunkSize) e.rowContainer.GetMemTracker().AttachTo(e.memTracker)