diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 9dd3ba3daa8f..5972693e0d27 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -154,9 +154,9 @@ func buildMockDataSource(opt mockDataSourceParameters) *mockDataSource { colData[i] = m.genColDatums(i) } - m.genData = make([]*chunk.Chunk, (m.p.rows+m.initCap-1)/m.initCap) + m.genData = make([]*chunk.Chunk, (m.p.rows+m.maxChunkSize-1)/m.maxChunkSize) for i := range m.genData { - m.genData[i] = chunk.NewChunkWithCapacity(retTypes(m), m.ctx.GetSessionVars().MaxChunkSize) + m.genData[i] = chunk.NewChunkWithCapacity(retTypes(m), m.maxChunkSize) } for i := 0; i < m.p.rows; i++ { @@ -555,14 +555,15 @@ func prepare4Join(testCase *hashJoinTestCase, innerExec, outerExec Executor) *Ha joinKeys = append(joinKeys, cols0[keyIdx]) } e := &HashJoinExec{ - baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), - concurrency: uint(testCase.concurrency), - joinType: 0, // InnerJoin - isOuterJoin: false, - innerKeys: joinKeys, - outerKeys: joinKeys, - innerExec: innerExec, - outerExec: outerExec, + baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, stringutil.StringerStr("HashJoin"), innerExec, outerExec), + concurrency: uint(testCase.concurrency), + joinType: 0, // InnerJoin + isOuterJoin: false, + innerKeys: joinKeys, + outerKeys: joinKeys, + innerExec: innerExec, + outerExec: outerExec, + innerEstCount: float64(testCase.rows), } defaultValues := make([]types.Datum, e.innerExec.Schema().Len()) lhsTypes, rhsTypes := retTypes(innerExec), retTypes(outerExec) @@ -663,13 +664,13 @@ func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { b.ResetTimer() for i := 0; i < b.N; i++ { b.StopTimer() - innerResultCh := make(chan *chunk.Chunk, 1) - go func() { - for _, chk := range dataSource1.genData { - innerResultCh <- chk - } - close(innerResultCh) - }() + exec.rowContainer = nil + exec.memTracker = memory.NewTracker(exec.id, exec.ctx.GetSessionVars().MemQuotaHashJoin) + innerResultCh := make(chan *chunk.Chunk, len(dataSource1.chunks)) + for _, chk := range dataSource1.chunks { + innerResultCh <- chk + } + close(innerResultCh) b.StartTimer() if err := exec.buildHashTableForList(innerResultCh); err != nil { @@ -690,4 +691,10 @@ func BenchmarkBuildHashTableForList(b *testing.B) { b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { benchmarkBuildHashTableForList(b, cas) }) + + cas.keyIdx = []int{0} + cas.rows = 10 + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) } diff --git a/executor/builder.go b/executor/builder.go index 210245065e69..a0d117dda0dc 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -991,10 +991,11 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } e := &HashJoinExec{ - baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), leftExec, rightExec), - concurrency: v.Concurrency, - joinType: v.JoinType, - isOuterJoin: v.JoinType.IsOuterJoin(), + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), leftExec, rightExec), + concurrency: v.Concurrency, + joinType: v.JoinType, + isOuterJoin: v.JoinType.IsOuterJoin(), + innerEstCount: v.Children()[v.InnerChildIdx].StatsCount(), } defaultValues := v.DefaultValues diff --git a/executor/hash_table.go b/executor/hash_table.go index af5693f236df..72efd12156b7 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -14,10 +14,166 @@ package executor import ( + "hash" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/memory" +) + +const ( + // estCountMaxFactor defines the factor of estCountMax with maxChunkSize. + // estCountMax is maxChunkSize * estCountMaxFactor, the maximum threshold of estCount. + // if estCount is larger than estCountMax, set estCount to estCountMax. + // Set this threshold to prevent innerEstCount being too large and causing a performance and memory regression. + estCountMaxFactor = 10 * 1024 + + // estCountMinFactor defines the factor of estCountMin with maxChunkSize. + // estCountMin is maxChunkSize * estCountMinFactor, the minimum threshold of estCount. + // If estCount is smaller than estCountMin, set estCount to 0. + // Set this threshold to prevent innerEstCount being too small and causing a performance regression. + estCountMinFactor = 8 + + // estCountDivisor defines the divisor of innerEstCount. + // Set this divisor to prevent innerEstCount being too large and causing a performance regression. + estCountDivisor = 8 ) -const maxEntrySliceLen = 8 * 1024 +// hashContext keeps the needed hash context of a db table in hash join. +type hashContext struct { + allTypes []*types.FieldType + keyColIdx []int + h hash.Hash64 + buf []byte +} + +// hashRowContainer handles the rows and the hash map of a table. +// TODO: support spilling out to disk when memory is limited. +type hashRowContainer struct { + records *chunk.List + hashTable *rowHashMap + + sc *stmtctx.StatementContext + hCtx *hashContext +} + +func newHashRowContainer(sctx sessionctx.Context, estCount int, hCtx *hashContext, initList *chunk.List) *hashRowContainer { + maxChunkSize := sctx.GetSessionVars().MaxChunkSize + // The estCount from cost model is not quite accurate and we need + // to avoid that it's too large to consume redundant memory. + // So I invent a rough protection, firstly divide it by estCountDivisor + // then set a maximum threshold and a minimum threshold. + estCount /= estCountDivisor + if estCount > maxChunkSize*estCountMaxFactor { + estCount = maxChunkSize * estCountMaxFactor + } + if estCount < maxChunkSize*estCountMinFactor { + estCount = 0 + } + c := &hashRowContainer{ + records: initList, + hashTable: newRowHashMap(estCount), + + sc: sctx.GetSessionVars().StmtCtx, + hCtx: hCtx, + } + return c +} + +func (c *hashRowContainer) GetMemTracker() *memory.Tracker { + return c.records.GetMemTracker() +} + +// GetMatchedRows get matched rows from probeRow. It can be called +// in multiple goroutines while each goroutine should keep its own +// h and buf. +func (c *hashRowContainer) GetMatchedRows(probeRow chunk.Row, hCtx *hashContext) (matched []chunk.Row, err error) { + hasNull, key, err := c.getJoinKeyFromChkRow(c.sc, probeRow, hCtx) + if err != nil || hasNull { + return + } + innerPtrs := c.hashTable.Get(key) + if len(innerPtrs) == 0 { + return + } + matched = make([]chunk.Row, 0, len(innerPtrs)) + for _, ptr := range innerPtrs { + matchedRow := c.records.GetRow(ptr) + var ok bool + ok, err = c.matchJoinKey(matchedRow, probeRow, hCtx) + if err != nil { + return + } + if !ok { + continue + } + matched = append(matched, matchedRow) + } + /* TODO(fengliyuan): add test case in this case + if len(matched) == 0 { + // noop + } + */ + return +} + +// matchJoinKey checks if join keys of buildRow and probeRow are logically equal. +func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx *hashContext) (ok bool, err error) { + return codec.EqualChunkRow(c.sc, + buildRow, c.hCtx.allTypes, c.hCtx.keyColIdx, + probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx) +} + +// PutChunk puts a chunk into hashRowContainer and build hash map. It's not thread-safe. +// key of hash table: hash value of key columns +// value of hash table: RowPtr of the corresponded row +func (c *hashRowContainer) PutChunk(chk *chunk.Chunk) error { + chkIdx := uint32(c.records.NumChunks()) + c.records.Add(chk) + var ( + hasNull bool + err error + key uint64 + ) + numRows := chk.NumRows() + for j := 0; j < numRows; j++ { + hasNull, key, err = c.getJoinKeyFromChkRow(c.sc, chk.GetRow(j), c.hCtx) + if err != nil { + return errors.Trace(err) + } + if hasNull { + continue + } + rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} + c.hashTable.Put(key, rowPtr) + } + return nil +} + +// getJoinKeyFromChkRow fetches join keys from row and calculate the hash value. +func (*hashRowContainer) getJoinKeyFromChkRow(sc *stmtctx.StatementContext, row chunk.Row, hCtx *hashContext) (hasNull bool, key uint64, err error) { + for _, i := range hCtx.keyColIdx { + if row.IsNull(i) { + return true, 0, nil + } + } + hCtx.h.Reset() + err = codec.HashChunkRow(sc, hCtx.h, row, hCtx.allTypes, hCtx.keyColIdx, hCtx.buf) + return false, hCtx.h.Sum64(), err +} + +func (c hashRowContainer) Len() int { + return c.hashTable.Len() +} + +const ( + initialEntrySliceLen = 64 + maxEntrySliceLen = 8 * 1024 +) type entry struct { ptr chunk.RowPtr @@ -25,20 +181,32 @@ type entry struct { } type entryStore struct { - slices [][]entry - sliceIdx uint32 - sliceLen uint32 + slices [][]entry +} + +func (es *entryStore) init() { + es.slices = [][]entry{make([]entry, 0, initialEntrySliceLen)} + // Reserve the first empty entry, so entryAddr{} can represent nullEntryAddr. + reserved := es.put(entry{}) + if reserved != nullEntryAddr { + panic("entryStore: first entry is not nullEntryAddr") + } } func (es *entryStore) put(e entry) entryAddr { - if es.sliceLen == maxEntrySliceLen { - es.slices = append(es.slices, make([]entry, 0, maxEntrySliceLen)) - es.sliceLen = 0 - es.sliceIdx++ + sliceIdx := uint32(len(es.slices) - 1) + slice := es.slices[sliceIdx] + if len(slice) == cap(slice) { + size := cap(slice) * 2 + if size >= maxEntrySliceLen { + size = maxEntrySliceLen + } + slice = make([]entry, 0, size) + es.slices = append(es.slices, slice) + sliceIdx++ } - addr := entryAddr{sliceIdx: es.sliceIdx, offset: es.sliceLen} - es.slices[es.sliceIdx] = append(es.slices[es.sliceIdx], e) - es.sliceLen++ + addr := entryAddr{sliceIdx: sliceIdx, offset: uint32(len(slice))} + es.slices[sliceIdx] = append(slice, e) return addr } @@ -56,20 +224,19 @@ var nullEntryAddr = entryAddr{} // rowHashMap stores multiple rowPtr of rows for a given key with minimum GC overhead. // A given key can store multiple values. // It is not thread-safe, should only be used in one goroutine. +// TODO(fengliyuan): add unit test for this. type rowHashMap struct { entryStore entryStore hashTable map[uint64]entryAddr length int } -// newRowHashMap creates a new rowHashMap. -func newRowHashMap() *rowHashMap { +// newRowHashMap creates a new rowHashMap. estCount means the estimated size of the hashMap. +// If unknown, set it to 0. +func newRowHashMap(estCount int) *rowHashMap { m := new(rowHashMap) - // TODO(fengliyuan): initialize the size of map from the estimated row count for better performance. - m.hashTable = make(map[uint64]entryAddr) - m.entryStore.slices = [][]entry{make([]entry, 0, 64)} - // Reserve the first empty entry, so entryAddr{} can represent nullEntryAddr. - m.entryStore.put(entry{}) + m.hashTable = make(map[uint64]entryAddr, estCount) + m.entryStore.init() return m } diff --git a/executor/hash_table_test.go b/executor/hash_table_test.go new file mode 100644 index 000000000000..5478368ea9bb --- /dev/null +++ b/executor/hash_table_test.go @@ -0,0 +1,50 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/chunk" +) + +func (s *pkgTestSuite) TestRowHashMap(c *C) { + m := newRowHashMap(0) + m.Put(1, chunk.RowPtr{ChkIdx: 1, RowIdx: 1}) + c.Check(m.Get(1), DeepEquals, []chunk.RowPtr{{ChkIdx: 1, RowIdx: 1}}) + + rawData := map[uint64][]chunk.RowPtr{} + for i := uint64(0); i < 10; i++ { + for j := uint64(0); j < initialEntrySliceLen*i; j++ { + rawData[i] = append(rawData[i], chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)}) + } + } + m = newRowHashMap(0) + // put all rawData into m vertically + for j := uint64(0); j < initialEntrySliceLen*9; j++ { + for i := 9; i >= 0; i-- { + i := uint64(i) + if !(j < initialEntrySliceLen*i) { + break + } + m.Put(i, rawData[i][j]) + } + } + // check + totalCount := 0 + for i := uint64(0); i < 10; i++ { + totalCount += len(rawData[i]) + c.Check(m.Get(i), DeepEquals, rawData[i]) + } + c.Check(m.Len(), Equals, totalCount) +} diff --git a/executor/join.go b/executor/join.go index d626451f7487..053bf04a78e1 100644 --- a/executor/join.go +++ b/executor/join.go @@ -16,7 +16,6 @@ package executor import ( "context" "fmt" - "hash" "hash/fnv" "sync" "sync/atomic" @@ -25,10 +24,8 @@ import ( "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/stringutil" ) @@ -42,16 +39,16 @@ var ( type HashJoinExec struct { baseExecutor - outerExec Executor - innerExec Executor - outerFilter expression.CNFExprs - outerKeys []*expression.Column - innerKeys []*expression.Column + outerExec Executor + innerExec Executor + innerEstCount float64 + outerFilter expression.CNFExprs + outerKeys []*expression.Column + innerKeys []*expression.Column // concurrency is the number of partition, build and join workers. concurrency uint - hashTable *rowHashMap - joinKeyBuf [][]byte + rowContainer *hashRowContainer innerFinished chan error // joinWorkerWaitGroup is for sync multiple join workers. joinWorkerWaitGroup sync.WaitGroup @@ -65,9 +62,6 @@ type HashJoinExec struct { // execution, to avoid the concurrency of joiner.chk and joiner.selected. joiners []joiner - outerKeyColIdx []int - innerKeyColIdx []int - innerResult *chunk.List outerChkResourceCh chan *outerChkResource outerResultChs []chan *chunk.Chunk joinChkResourceCh []chan *chunk.Chunk @@ -141,10 +135,6 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.prepared = false e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaHashJoin) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) - e.joinKeyBuf = make([][]byte, e.concurrency) - for i := range e.joinKeyBuf { - e.joinKeyBuf[i] = make([]byte, 1) - } e.closeCh = make(chan struct{}) e.finished.Store(false) @@ -152,34 +142,6 @@ func (e *HashJoinExec) Open(ctx context.Context) error { return nil } -func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, h hash.Hash64, buf []byte) (hasNull bool, key uint64, err error) { - var keyColIdx []int - var allTypes []*types.FieldType - if isOuterKey { - keyColIdx = e.outerKeyColIdx - allTypes = retTypes(e.outerExec) - } else { - keyColIdx = e.innerKeyColIdx - allTypes = retTypes(e.innerExec) - } - - for _, i := range keyColIdx { - if row.IsNull(i) { - return true, 0, nil - } - } - h.Reset() - err = codec.HashChunkRow(e.ctx.GetSessionVars().StmtCtx, h, row, allTypes, keyColIdx, buf) - return false, h.Sum64(), err -} - -func (e *HashJoinExec) matchJoinKey(inner, outer chunk.Row) (ok bool, err error) { - innerAllTypes, outerAllTypes := retTypes(e.innerExec), retTypes(e.outerExec) - return codec.EqualChunkRow(e.ctx.GetSessionVars().StmtCtx, - inner, innerAllTypes, e.innerKeyColIdx, - outer, outerAllTypes, e.outerKeyColIdx) -} - // fetchOuterChunks get chunks from fetches chunks from the big table in a background goroutine // and sends the chunks to multiple channels which will be read by multiple join workers. func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) { @@ -245,7 +207,7 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { return false, err } } - if e.hashTable.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { + if e.rowContainer.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { return true, nil } return false, nil @@ -257,9 +219,6 @@ var innerResultLabel fmt.Stringer = stringutil.StringerStr("innerResult") // and append them to e.innerResult. func (e *HashJoinExec) fetchInnerRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh <-chan struct{}) { defer close(chkCh) - e.innerResult = chunk.NewList(e.innerExec.base().retFieldTypes, e.initCap, e.maxChunkSize) - e.innerResult.GetMemTracker().AttachTo(e.memTracker) - e.innerResult.GetMemTracker().SetLabel(innerResultLabel) var err error for { if e.finished.Load().(bool) { @@ -280,7 +239,6 @@ func (e *HashJoinExec) fetchInnerRows(ctx context.Context, chkCh chan<- *chunk.C case <-e.closeCh: return case chkCh <- chk: - e.innerResult.Add(chk) } } } @@ -315,11 +273,6 @@ func (e *HashJoinExec) initializeForProbe() { // e.joinResultCh is for transmitting the join result chunks to the main // thread. e.joinResultCh = make(chan *hashjoinWorkerResult, e.concurrency+1) - - e.outerKeyColIdx = make([]int, len(e.outerKeys)) - for i := range e.outerKeys { - e.outerKeyColIdx[i] = e.outerKeys[i].Index - } } func (e *HashJoinExec) fetchOuterAndProbeHashTable(ctx context.Context) { @@ -327,12 +280,17 @@ func (e *HashJoinExec) fetchOuterAndProbeHashTable(ctx context.Context) { e.joinWorkerWaitGroup.Add(1) go util.WithRecovery(func() { e.fetchOuterChunks(ctx) }, e.handleOuterFetcherPanic) + outerKeyColIdx := make([]int, len(e.outerKeys)) + for i := range e.outerKeys { + outerKeyColIdx[i] = e.outerKeys[i].Index + } + // Start e.concurrency join workers to probe hash table and join inner and // outer rows. for i := uint(0); i < e.concurrency; i++ { e.joinWorkerWaitGroup.Add(1) workID := i - go util.WithRecovery(func() { e.runJoinWorker(workID) }, e.handleJoinWorkerPanic) + go util.WithRecovery(func() { e.runJoinWorker(workID, outerKeyColIdx) }, e.handleJoinWorkerPanic) } go util.WithRecovery(e.waitJoinWorkersAndCloseResultChan, nil) } @@ -359,11 +317,10 @@ func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { close(e.joinResultCh) } -func (e *HashJoinExec) runJoinWorker(workerID uint) { +func (e *HashJoinExec) runJoinWorker(workerID uint, outerKeyColIdx []int) { var ( outerResult *chunk.Chunk selected = make([]bool, 0, chunk.InitialCapacity) - h = fnv.New64() ) ok, joinResult := e.getNewJoinResult(workerID) if !ok { @@ -374,6 +331,12 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { emptyOuterResult := &outerChkResource{ dest: e.outerResultChs[workerID], } + hCtx := &hashContext{ + allTypes: retTypes(e.outerExec), + keyColIdx: outerKeyColIdx, + h: fnv.New64(), + buf: make([]byte, 1), + } for ok := true; ok; { if e.finished.Load().(bool) { break @@ -386,7 +349,7 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { if !ok { break } - ok, joinResult = e.join2Chunk(workerID, outerResult, joinResult, selected, h) + ok, joinResult = e.join2Chunk(workerID, outerResult, hCtx, joinResult, selected) if !ok { break } @@ -401,36 +364,14 @@ func (e *HashJoinExec) runJoinWorker(workerID uint) { } } -func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.Row, - joinResult *hashjoinWorkerResult, h hash.Hash64) (bool, *hashjoinWorkerResult) { - hasNull, joinKey, err := e.getJoinKeyFromChkRow(true, outerRow, h, e.joinKeyBuf[workerID]) +func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.Row, hCtx *hashContext, + joinResult *hashjoinWorkerResult) (bool, *hashjoinWorkerResult) { + innerRows, err := e.rowContainer.GetMatchedRows(outerRow, hCtx) if err != nil { joinResult.err = err return false, joinResult } - if hasNull { - e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) - return true, joinResult - } - innerPtrs := e.hashTable.Get(joinKey) - if len(innerPtrs) == 0 { - e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) - return true, joinResult - } - innerRows := make([]chunk.Row, 0, len(innerPtrs)) - for _, ptr := range innerPtrs { - matchedInner := e.innerResult.GetRow(ptr) - ok, err := e.matchJoinKey(matchedInner, outerRow) - if err != nil { - joinResult.err = err - return false, joinResult - } - if !ok { - continue - } - innerRows = append(innerRows, matchedInner) - } - if len(innerRows) == 0 { // TODO(fengliyuan): add test case + if len(innerRows) == 0 { e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } @@ -472,8 +413,8 @@ func (e *HashJoinExec) getNewJoinResult(workerID uint) (bool, *hashjoinWorkerRes return ok, joinResult } -func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResult *hashjoinWorkerResult, - selected []bool, h hash.Hash64) (ok bool, _ *hashjoinWorkerResult) { +func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, hCtx *hashContext, joinResult *hashjoinWorkerResult, + selected []bool) (ok bool, _ *hashjoinWorkerResult) { var err error selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(outerChk), selected) if err != nil { @@ -484,7 +425,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu if !selected[i] { // process unmatched outer rows e.joiners[workerID].onMissMatch(false, outerChk.GetRow(i), joinResult.chk) } else { // process matched outer rows - ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), joinResult, h) + ok, joinResult = e.joinMatchedOuterRow2Chunk(workerID, outerChk.GetRow(i), hCtx, joinResult) if !ok { return false, joinResult } @@ -556,40 +497,30 @@ func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { } // buildHashTableForList builds hash table from `list`. -// key of hash table: hash value of key columns -// value of hash table: RowPtr of the corresponded row func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) error { - e.hashTable = newRowHashMap() - e.innerKeyColIdx = make([]int, len(e.innerKeys)) + innerKeyColIdx := make([]int, len(e.innerKeys)) for i := range e.innerKeys { - e.innerKeyColIdx[i] = e.innerKeys[i].Index - } - var ( - hasNull bool - err error - key uint64 - buf = make([]byte, 1) - ) - - h := fnv.New64() - chkIdx := uint32(0) + innerKeyColIdx[i] = e.innerKeys[i].Index + } + allTypes := e.innerExec.base().retFieldTypes + hCtx := &hashContext{ + allTypes: allTypes, + keyColIdx: innerKeyColIdx, + h: fnv.New64(), + buf: make([]byte, 1), + } + initList := chunk.NewList(allTypes, e.initCap, e.maxChunkSize) + e.rowContainer = newHashRowContainer(e.ctx, int(e.innerEstCount), hCtx, initList) + e.rowContainer.GetMemTracker().AttachTo(e.memTracker) + e.rowContainer.GetMemTracker().SetLabel(innerResultLabel) for chk := range innerResultCh { if e.finished.Load().(bool) { return nil } - numRows := chk.NumRows() - for j := 0; j < numRows; j++ { - hasNull, key, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), h, buf) - if err != nil { - return errors.Trace(err) - } - if hasNull { - continue - } - rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} - e.hashTable.Put(key, rowPtr) + err := e.rowContainer.PutChunk(chk) + if err != nil { + return err } - chkIdx++ } return nil }