diff --git a/config/config.go b/config/config.go index adeb4cc1c623c..4e2370915564c 100644 --- a/config/config.go +++ b/config/config.go @@ -67,6 +67,7 @@ type Config struct { RunDDL bool `toml:"run-ddl" json:"run-ddl"` SplitTable bool `toml:"split-table" json:"split-table"` TokenLimit uint `toml:"token-limit" json:"token-limit"` + OOMUseTmpStorage bool `toml:"oom-use-tmp-storage" json:"oom-use-tmp-storage"` OOMAction string `toml:"oom-action" json:"oom-action"` MemQuotaQuery int64 `toml:"mem-quota-query" json:"mem-quota-query"` EnableStreaming bool `toml:"enable-streaming" json:"enable-streaming"` @@ -336,6 +337,7 @@ var defaultConf = Config{ SplitTable: true, Lease: "45s", TokenLimit: 1000, + OOMUseTmpStorage: true, OOMAction: "log", MemQuotaQuery: 32 << 30, EnableStreaming: false, diff --git a/config/config.toml.example b/config/config.toml.example index f4891850d7690..b309080794d19 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -31,13 +31,16 @@ split-table = true # The limit of concurrent executed sessions. token-limit = 1000 -# Only print a log when out of memory quota. -# Valid options: ["log", "cancel"] -oom-action = "log" - # Set the memory quota for a query in bytes. Default: 32GB mem-quota-query = 34359738368 +# Set to true to enable use of temporary disk for some executors when mem-quota-query is exceeded. +oom-use-tmp-storage = true + +# What to do when mem-quota-query is exceeded and can not be spilled over to disk any more. +# Valid options: ["log", "cancel"] +oom-action = "log" + # Enable coprocessor streaming. enable-streaming = false diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 82114306773d5..891a93f8d1536 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/pingcap/log" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -34,6 +35,7 @@ import ( "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/stringutil" + "go.uber.org/zap/zapcore" ) var ( @@ -522,6 +524,7 @@ type hashJoinTestCase struct { concurrency int ctx sessionctx.Context keyIdx []int + disk bool } func (tc hashJoinTestCase) columns() []*expression.Column { @@ -532,8 +535,8 @@ func (tc hashJoinTestCase) columns() []*expression.Column { } func (tc hashJoinTestCase) String() string { - return fmt.Sprintf("(rows:%v, concurency:%v, joinKeyIdx: %v)", - tc.rows, tc.concurrency, tc.keyIdx) + return fmt.Sprintf("(rows:%v, concurency:%v, joinKeyIdx: %v, disk:%v)", + tc.rows, tc.concurrency, tc.keyIdx, tc.disk) } func defaultHashJoinTestCase() *hashJoinTestCase { @@ -572,6 +575,13 @@ func prepare4Join(testCase *hashJoinTestCase, innerExec, outerExec Executor) *Ha e.joiners[i] = newJoiner(testCase.ctx, e.joinType, true, defaultValues, nil, lhsTypes, rhsTypes) } + memLimit := int64(-1) + if testCase.disk { + memLimit = 1 + } + t := memory.NewTracker(stringutil.StringerStr("root of prepare4Join"), memLimit) + t.SetActionOnExceed(nil) + e.ctx.GetSessionVars().StmtCtx.MemTracker = t return e } @@ -620,10 +630,17 @@ func benchmarkHashJoinExecWithCase(b *testing.B, casTest *hashJoinTestCase) { b.Fatal(err) } b.StopTimer() + if exec.rowContainer.alreadySpilled() != casTest.disk { + b.Fatal("wrong usage with disk") + } } } func BenchmarkHashJoinExec(b *testing.B) { + lvl := log.GetLevel() + log.SetLevel(zapcore.ErrorLevel) + defer log.SetLevel(lvl) + b.ReportAllocs() cas := defaultHashJoinTestCase() b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { @@ -634,6 +651,19 @@ func BenchmarkHashJoinExec(b *testing.B) { b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { benchmarkHashJoinExecWithCase(b, cas) }) + + cas.keyIdx = []int{0} + cas.disk = true + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) + + cas.keyIdx = []int{0} + cas.disk = true + cas.rows = 1000 + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkHashJoinExecWithCase(b, cas) + }) } func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { @@ -656,16 +686,16 @@ func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { dataSource2 := buildMockDataSource(opt) dataSource1.prepareChunks() - exec := prepare4Join(casTest, dataSource1, dataSource2) - tmpCtx := context.Background() - if err := exec.Open(tmpCtx); err != nil { - b.Fatal(err) - } b.ResetTimer() for i := 0; i < b.N; i++ { b.StopTimer() - exec.rowContainer = nil - exec.memTracker = memory.NewTracker(exec.id, exec.ctx.GetSessionVars().MemQuotaHashJoin) + exec := prepare4Join(casTest, dataSource1, dataSource2) + tmpCtx := context.Background() + if err := exec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + exec.prepared = true + innerResultCh := make(chan *chunk.Chunk, len(dataSource1.chunks)) for _, chk := range dataSource1.chunks { innerResultCh <- chk @@ -676,25 +706,37 @@ func benchmarkBuildHashTableForList(b *testing.B, casTest *hashJoinTestCase) { if err := exec.buildHashTableForList(innerResultCh); err != nil { b.Fatal(err) } + + if err := exec.Close(); err != nil { + b.Fatal(err) + } b.StopTimer() + if exec.rowContainer.alreadySpilled() != casTest.disk { + b.Fatal("wrong usage with disk") + } } } func BenchmarkBuildHashTableForList(b *testing.B) { + lvl := log.GetLevel() + log.SetLevel(zapcore.ErrorLevel) + defer log.SetLevel(lvl) + b.ReportAllocs() cas := defaultHashJoinTestCase() - b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { - benchmarkBuildHashTableForList(b, cas) - }) - - cas.keyIdx = []int{0} - b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { - benchmarkBuildHashTableForList(b, cas) - }) - - cas.keyIdx = []int{0} - cas.rows = 10 - b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { - benchmarkBuildHashTableForList(b, cas) - }) + rows := []int{10, 100000} + keyIdxs := [][]int{{0, 1}, {0}} + disks := []bool{false, true} + for _, row := range rows { + for _, keyIdx := range keyIdxs { + for _, disk := range disks { + cas.rows = row + cas.keyIdx = keyIdx + cas.disk = disk + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkBuildHashTableForList(b, cas) + }) + } + } + } } diff --git a/executor/hash_table.go b/executor/hash_table.go index 5b52c96b60b6e..6161d8802214a 100644 --- a/executor/hash_table.go +++ b/executor/hash_table.go @@ -16,6 +16,8 @@ package executor import ( "hash" "hash/fnv" + "sync" + "sync/atomic" "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx" @@ -23,7 +25,10 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/disk" + "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "go.uber.org/zap" ) const ( @@ -73,17 +78,34 @@ func (hc *hashContext) initHash(rows int) { } // hashRowContainer handles the rows and the hash map of a table. -// TODO: support spilling out to disk when memory is limited. type hashRowContainer struct { - records *chunk.List - hashTable *rowHashMap - sc *stmtctx.StatementContext hCtx *hashContext + + // hashTable stores the map of hashKey and RowPtr + hashTable *rowHashMap + + // memTracker is the reference of records.GetMemTracker(). + // records would be set to nil for garbage collection when spilling is activated + // so we need this reference. + memTracker *memory.Tracker + + // records stores the chunks in memory. + records *chunk.List + // recordsInDisk stores the chunks in disk. + recordsInDisk *chunk.ListInDisk + + // exceeded indicates that records have exceeded memQuota during + // this PutChunk and we should spill now. + // It's for concurrency usage, so access it with atomic. + exceeded uint32 + // spilled indicates that records have spilled out into disk. + // It's for concurrency usage, so access it with atomic. + spilled uint32 } -func newHashRowContainer(sctx sessionctx.Context, estCount int, hCtx *hashContext, initList *chunk.List) *hashRowContainer { - maxChunkSize := sctx.GetSessionVars().MaxChunkSize +func newHashRowContainer(sCtx sessionctx.Context, estCount int, hCtx *hashContext) *hashRowContainer { + maxChunkSize := sCtx.GetSessionVars().MaxChunkSize // The estCount from cost model is not quite accurate and we need // to avoid that it's too large to consume redundant memory. // So I invent a rough protection, firstly divide it by estCountDivisor @@ -95,18 +117,17 @@ func newHashRowContainer(sctx sessionctx.Context, estCount int, hCtx *hashContex if estCount < maxChunkSize*estCountMinFactor { estCount = 0 } + initList := chunk.NewList(hCtx.allTypes, maxChunkSize, maxChunkSize) c := &hashRowContainer{ - records: initList, - hashTable: newRowHashMap(estCount), - - sc: sctx.GetSessionVars().StmtCtx, + sc: sCtx.GetSessionVars().StmtCtx, hCtx: hCtx, + + hashTable: newRowHashMap(estCount), + memTracker: initList.GetMemTracker(), + records: initList, } - return c -} -func (c *hashRowContainer) GetMemTracker() *memory.Tracker { - return c.records.GetMemTracker() + return c } // GetMatchedRows get matched rows from probeRow. It can be called @@ -122,8 +143,16 @@ func (c *hashRowContainer) GetMatchedRows(probeRow chunk.Row, hCtx *hashContext) return } matched = make([]chunk.Row, 0, len(innerPtrs)) + var matchedRow chunk.Row for _, ptr := range innerPtrs { - matchedRow := c.records.GetRow(ptr) + if c.alreadySpilled() { + matchedRow, err = c.recordsInDisk.GetRow(ptr) + if err != nil { + return + } + } else { + matchedRow = c.records.GetRow(ptr) + } var ok bool ok, err = c.matchJoinKey(matchedRow, probeRow, hCtx) if err != nil { @@ -134,11 +163,6 @@ func (c *hashRowContainer) GetMatchedRows(probeRow chunk.Row, hCtx *hashContext) } matched = append(matched, matchedRow) } - /* TODO(fengliyuan): add test case in this case - if len(matched) == 0 { - // noop - } - */ return } @@ -149,14 +173,51 @@ func (c *hashRowContainer) matchJoinKey(buildRow, probeRow chunk.Row, probeHCtx probeRow, probeHCtx.allTypes, probeHCtx.keyColIdx) } +func (c *hashRowContainer) spillToDisk() (err error) { + N := c.records.NumChunks() + c.recordsInDisk = chunk.NewListInDisk(c.hCtx.allTypes) + for i := 0; i < N; i++ { + chk := c.records.GetChunk(i) + err = c.recordsInDisk.Add(chk) + if err != nil { + return + } + } + return +} + +// alreadySpilled indicates that records have spilled out into disk. +func (c *hashRowContainer) alreadySpilled() bool { return c.recordsInDisk != nil } + +// alreadySpilledSafe indicates that records have spilled out into disk. It's thread-safe. +func (c *hashRowContainer) alreadySpilledSafe() bool { return atomic.LoadUint32(&c.spilled) == 1 } + // PutChunk puts a chunk into hashRowContainer and build hash map. It's not thread-safe. // key of hash table: hash value of key columns // value of hash table: RowPtr of the corresponded row func (c *hashRowContainer) PutChunk(chk *chunk.Chunk) error { - chkIdx := uint32(c.records.NumChunks()) + var chkIdx uint32 + if c.alreadySpilled() { + // append chk to disk. + chkIdx = uint32(c.recordsInDisk.NumChunks()) + err := c.recordsInDisk.Add(chk) + if err != nil { + return err + } + } else { + chkIdx = uint32(c.records.NumChunks()) + c.records.Add(chk) + if atomic.LoadUint32(&c.exceeded) != 0 { + err := c.spillToDisk() + if err != nil { + return err + } + c.records = nil // GC its internal chunks. + c.memTracker.Consume(-c.memTracker.BytesConsumed()) + atomic.StoreUint32(&c.spilled, 1) + } + } numRows := chk.NumRows() - - c.records.Add(chk) c.hCtx.initHash(numRows) hCtx := c.hCtx @@ -189,10 +250,58 @@ func (*hashRowContainer) getJoinKeyFromChkRow(sc *stmtctx.StatementContext, row return false, hCtx.hashVals[0].Sum64(), err } +// Len returns the length of the records in hashRowContainer. func (c hashRowContainer) Len() int { return c.hashTable.Len() } +func (c *hashRowContainer) Close() error { + if c.recordsInDisk != nil { + return c.recordsInDisk.Close() + } + return nil +} + +// GetMemTracker returns the underlying memory usage tracker in hashRowContainer. +func (c *hashRowContainer) GetMemTracker() *memory.Tracker { return c.memTracker } + +// GetDiskTracker returns the underlying disk usage tracker in hashRowContainer. +func (c *hashRowContainer) GetDiskTracker() *disk.Tracker { return c.recordsInDisk.GetDiskTracker() } + +// ActionSpill returns a memory.ActionOnExceed for spilling over to disk. +func (c *hashRowContainer) ActionSpill() memory.ActionOnExceed { + return &spillDiskAction{c: c} +} + +// spillDiskAction implements memory.ActionOnExceed for chunk.List. If +// the memory quota of a query is exceeded, spillDiskAction.Action is +// triggered. +type spillDiskAction struct { + once sync.Once + c *hashRowContainer + fallbackAction memory.ActionOnExceed +} + +// Action sends a signal to trigger spillToDisk method of hashRowContainer +// and if it is already triggered before, call its fallbackAction. +func (a *spillDiskAction) Action(t *memory.Tracker) { + if a.c.alreadySpilledSafe() { + if a.fallbackAction != nil { + a.fallbackAction.Action(t) + } + } + a.once.Do(func() { + atomic.StoreUint32(&a.c.exceeded, 1) + logutil.BgLogger().Info("memory exceeds quota, spill to disk now.", zap.String("memory", t.String())) + }) +} + +func (a *spillDiskAction) SetFallback(fallback memory.ActionOnExceed) { + a.fallbackAction = fallback +} + +func (a *spillDiskAction) SetLogHook(hook func(uint64)) {} + const ( initialEntrySliceLen = 64 maxEntrySliceLen = 8 * 1024 @@ -247,7 +356,6 @@ var nullEntryAddr = entryAddr{} // rowHashMap stores multiple rowPtr of rows for a given key with minimum GC overhead. // A given key can store multiple values. // It is not thread-safe, should only be used in one goroutine. -// TODO(fengliyuan): add unit test for this. type rowHashMap struct { entryStore entryStore hashTable map[uint64]entryAddr diff --git a/executor/hash_table_test.go b/executor/hash_table_test.go index 5478368ea9bbe..75a78c881a134 100644 --- a/executor/hash_table_test.go +++ b/executor/hash_table_test.go @@ -14,8 +14,16 @@ package executor import ( + "fmt" + "hash" + "hash/fnv" + . "github.com/pingcap/check" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/mock" ) func (s *pkgTestSuite) TestRowHashMap(c *C) { @@ -48,3 +56,124 @@ func (s *pkgTestSuite) TestRowHashMap(c *C) { } c.Check(m.Len(), Equals, totalCount) } + +func initBuildChunk(numRows int) (*chunk.Chunk, []*types.FieldType) { + numCols := 6 + colTypes := make([]*types.FieldType, 0, numCols) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeVarchar}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeVarchar}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeNewDecimal}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeJSON}) + + oldChk := chunk.NewChunkWithCapacity(colTypes, numRows) + for i := 0; i < numRows; i++ { + str := fmt.Sprintf("%d.12345", i) + oldChk.AppendNull(0) + oldChk.AppendInt64(1, int64(i)) + oldChk.AppendString(2, str) + oldChk.AppendString(3, str) + oldChk.AppendMyDecimal(4, types.NewDecFromStringForTest(str)) + oldChk.AppendJSON(5, json.CreateBinary(str)) + } + return oldChk, colTypes +} + +func initProbeChunk(numRows int) (*chunk.Chunk, []*types.FieldType) { + numCols := 3 + colTypes := make([]*types.FieldType, 0, numCols) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeLonglong}) + colTypes = append(colTypes, &types.FieldType{Tp: mysql.TypeVarchar}) + + oldChk := chunk.NewChunkWithCapacity(colTypes, numRows) + for i := 0; i < numRows; i++ { + str := fmt.Sprintf("%d.12345", i) + oldChk.AppendNull(0) + oldChk.AppendInt64(1, int64(i)) + oldChk.AppendString(2, str) + } + return oldChk, colTypes +} + +type hashCollision struct { + count int +} + +func (h *hashCollision) Sum64() uint64 { + h.count++ + return 0 +} +func (h hashCollision) Write(p []byte) (n int, err error) { return len(p), nil } +func (h hashCollision) Reset() {} +func (h hashCollision) Sum(b []byte) []byte { panic("not implemented") } +func (h hashCollision) Size() int { panic("not implemented") } +func (h hashCollision) BlockSize() int { panic("not implemented") } + +func (s *pkgTestSuite) TestHashRowContainer(c *C) { + hashFunc := func() hash.Hash64 { + return fnv.New64() + } + s.testHashRowContainer(c, hashFunc, false) + s.testHashRowContainer(c, hashFunc, true) + + h := &hashCollision{count: 0} + hashFuncCollision := func() hash.Hash64 { + return h + } + s.testHashRowContainer(c, hashFuncCollision, false) + c.Assert(h.count > 0, IsTrue) +} + +func (s *pkgTestSuite) testHashRowContainer(c *C, hashFunc func() hash.Hash64, spill bool) { + sctx := mock.NewContext() + var err error + numRows := 10 + + chk0, colTypes := initBuildChunk(numRows) + chk1, _ := initBuildChunk(numRows) + + hCtx := &hashContext{ + allTypes: colTypes, + keyColIdx: []int{1, 2}, + } + hCtx.hasNull = make([]bool, numRows) + for i := 0; i < numRows; i++ { + hCtx.hashVals = append(hCtx.hashVals, hashFunc()) + } + rowContainer := newHashRowContainer(sctx, 0, hCtx) + tracker := rowContainer.GetMemTracker() + tracker.SetLabel(innerResultLabel) + if spill { + rowContainer.ActionSpill().Action(tracker) + tracker.SetBytesLimit(1) + } + err = rowContainer.PutChunk(chk0) + c.Assert(err, IsNil) + err = rowContainer.PutChunk(chk1) + c.Assert(err, IsNil) + + c.Assert(rowContainer.alreadySpilled(), Equals, spill) + c.Assert(rowContainer.alreadySpilledSafe(), Equals, spill) + c.Assert(rowContainer.GetMemTracker().BytesConsumed() == 0, Equals, spill) + c.Assert(rowContainer.GetMemTracker().BytesConsumed() > 0, Equals, !spill) + if rowContainer.alreadySpilled() { + c.Assert(rowContainer.GetDiskTracker(), NotNil) + c.Assert(rowContainer.GetDiskTracker().BytesConsumed() > 0, Equals, true) + } + + probeChk, probeColType := initProbeChunk(2) + probeRow := probeChk.GetRow(1) + probeCtx := &hashContext{ + allTypes: probeColType, + keyColIdx: []int{1, 2}, + } + probeCtx.hasNull = make([]bool, 1) + probeCtx.hashVals = append(hCtx.hashVals, hashFunc()) + matched, err := rowContainer.GetMatchedRows(probeRow, probeCtx) + c.Assert(err, IsNil) + c.Assert(len(matched), Equals, 2) + c.Assert(matched[0].GetDatumRow(colTypes), DeepEquals, chk0.GetRow(1).GetDatumRow(colTypes)) + c.Assert(matched[1].GetDatumRow(colTypes), DeepEquals, chk1.GetRow(1).GetDatumRow(colTypes)) +} diff --git a/executor/join.go b/executor/join.go index 480ede9a0e5d3..d1cf41533d227 100644 --- a/executor/join.go +++ b/executor/join.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util" @@ -118,8 +119,8 @@ func (e *HashJoinExec) Close() error { } e.outerChkResourceCh = nil e.joinChkResourceCh = nil + terror.Call(e.rowContainer.Close) } - e.memTracker = nil err := e.baseExecutor.Close() return err @@ -132,7 +133,7 @@ func (e *HashJoinExec) Open(ctx context.Context) error { } e.prepared = false - e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaHashJoin) + e.memTracker = memory.NewTracker(e.id, -1) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) e.closeCh = make(chan struct{}) @@ -212,7 +213,7 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { return false, nil } -var innerResultLabel fmt.Stringer = stringutil.StringerStr("innerResult") +var innerResultLabel fmt.Stringer = stringutil.StringerStr("hashJoin.innerResult") // fetchInnerRows fetches all rows from inner executor, // and append them to e.innerResult. @@ -504,10 +505,13 @@ func (e *HashJoinExec) buildHashTableForList(innerResultCh <-chan *chunk.Chunk) allTypes: allTypes, keyColIdx: innerKeyColIdx, } - initList := chunk.NewList(allTypes, e.initCap, e.maxChunkSize) - e.rowContainer = newHashRowContainer(e.ctx, int(e.innerEstCount), hCtx, initList) + e.rowContainer = newHashRowContainer(e.ctx, int(e.innerEstCount), hCtx) e.rowContainer.GetMemTracker().AttachTo(e.memTracker) e.rowContainer.GetMemTracker().SetLabel(innerResultLabel) + if config.GetGlobalConfig().OOMUseTmpStorage { + actionSpill := e.rowContainer.ActionSpill() + e.ctx.GetSessionVars().StmtCtx.MemTracker.FallbackOldAndSetNewAction(actionSpill) + } for chk := range innerResultCh { if e.finished.Load().(bool) { return nil diff --git a/executor/join_test.go b/executor/join_test.go index 5b838260e2b33..6dbbeb4edf56f 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -20,7 +20,9 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" ) @@ -35,6 +37,34 @@ func (s *testSuite2) TestJoinPanic(c *C) { c.Check(err, NotNil) } +func (s *testSuite2) TestJoinInDisk(c *C) { + originCfg := config.GetGlobalConfig() + newConf := config.NewConfig() + newConf.OOMUseTmpStorage = true + config.StoreGlobalConfig(newConf) + defer config.StoreGlobalConfig(originCfg) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + sm := &mockSessionManager1{ + PS: make([]*util.ProcessInfo, 0), + } + tk.Se.SetSessionManager(sm) + s.domain.ExpensiveQueryHandle().SetSessionManager(sm) + + // TODO(fengliyuan): how to ensure that it is using disk really? + tk.MustExec("set @@tidb_mem_quota_query=1;") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 int, c2 int)") + tk.MustExec("create table t1(c1 int, c2 int)") + tk.MustExec("insert into t values(1,1),(2,2)") + tk.MustExec("insert into t1 values(2,3),(4,4)") + result := tk.MustQuery("select /*+ TIDB_HJ(t, t2) */ * from t, t1 where t.c1 = t1.c1") + result.Check(testkit.Rows("2 2 2 3")) +} + func (s *testSuite2) TestJoin(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/executor/joiner.go b/executor/joiner.go index af8faef85050f..22ceccccf2ee5 100644 --- a/executor/joiner.go +++ b/executor/joiner.go @@ -53,8 +53,8 @@ type joiner interface { // rows are appended to `chk`. The size of `chk` is limited to MaxChunkSize. // Note that when the outer row is considered unmatched, we need to differentiate // whether the join conditions return null or false, because that matters for - // AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, and the result is reflected - // by the second return value; for other join types, we always return false. + // AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, by setting the return + // value isNull; for other join types, isNull is always false. // // NOTE: Callers need to call this function multiple times to consume all // the inner rows for an outer row, and decide whether the outer row can be diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 72f29a36f5ef8..45521b5494da2 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -277,7 +277,7 @@ type PhysicalIndexJoin struct { // IdxColLens stores the length of each index column. IdxColLens []int // CompareFilters stores the filters for last column if those filters need to be evaluated during execution. - // e.g. select * from t where t.a = t1.a and t.b > t1.b and t.b < t1.b+10 + // e.g. select * from t, t1 where t.a = t1.a and t.b > t1.b and t.b < t1.b+10 // If there's index(t.a, t.b). All the filters can be used to construct index range but t.b > t1.b and t.b < t1.b=10 // need to be evaluated after we fetch the data of t1. // This struct stores them and evaluate them to ranges. diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c7ad86f29e5e0..8680493809d32 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -992,6 +992,8 @@ type Concurrency struct { type MemQuota struct { // MemQuotaQuery defines the memory quota for a query. MemQuotaQuery int64 + + // TODO: remove them below sometime, it should have only one Quota(MemQuotaQuery). // MemQuotaHashJoin defines the memory quota for a hash join executor. MemQuotaHashJoin int64 // MemQuotaMergeJoin defines the memory quota for a merge join executor. diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index c1c9c711ff82f..b489bc56face2 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -93,7 +93,8 @@ const ( // "tidb_mem_quota_indexlookupreader": control the memory quota of "IndexLookUpExecutor". // "tidb_mem_quota_indexlookupjoin": control the memory quota of "IndexLookUpJoin". // "tidb_mem_quota_nestedloopapply": control the memory quota of "NestedLoopApplyExec". - TIDBMemQuotaQuery = "tidb_mem_quota_query" // Bytes. + TIDBMemQuotaQuery = "tidb_mem_quota_query" // Bytes. + // TODO: remove them below sometime, it should have only one Quota(TIDBMemQuotaQuery). TIDBMemQuotaHashJoin = "tidb_mem_quota_hashjoin" // Bytes. TIDBMemQuotaMergeJoin = "tidb_mem_quota_mergejoin" // Bytes. TIDBMemQuotaSort = "tidb_mem_quota_sort" // Bytes. diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 9ee35c7bf3ae6..ed0183db5ad12 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -70,7 +70,7 @@ func (s *testVarsutilSuite) TestNewSessionVars(c *C) { c.Assert(vars.DistSQLScanConcurrency, Equals, DefDistSQLScanConcurrency) c.Assert(vars.MaxChunkSize, Equals, DefMaxChunkSize) c.Assert(vars.DMLBatchSize, Equals, DefDMLBatchSize) - c.Assert(vars.MemQuotaQuery, Equals, int64(config.GetGlobalConfig().MemQuotaQuery)) + c.Assert(vars.MemQuotaQuery, Equals, config.GetGlobalConfig().MemQuotaQuery) c.Assert(vars.MemQuotaHashJoin, Equals, int64(DefTiDBMemQuotaHashJoin)) c.Assert(vars.MemQuotaMergeJoin, Equals, int64(DefTiDBMemQuotaMergeJoin)) c.Assert(vars.MemQuotaSort, Equals, int64(DefTiDBMemQuotaSort)) diff --git a/tidb-server/main.go b/tidb-server/main.go index 6813198ad0687..761867736f439 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -362,7 +362,7 @@ func loadConfig() string { // hotReloadConfigItems lists all config items which support hot-reload. var hotReloadConfigItems = []string{"Performance.MaxProcs", "Performance.MaxMemory", "Performance.CrossJoin", "Performance.FeedbackProbability", "Performance.QueryFeedbackLimit", "Performance.PseudoEstimateRatio", - "OOMAction", "MemQuotaQuery", "StmtSummary.MaxStmtCount", "StmtSummary.MaxSQLLength"} + "OOMUseTmpStorage", "OOMAction", "MemQuotaQuery", "StmtSummary.MaxStmtCount", "StmtSummary.MaxSQLLength"} func reloadConfig(nc, c *config.Config) { // Just a part of config items need to be reload explicitly. diff --git a/util/memory/action.go b/util/memory/action.go index 9e70752555718..a71d74a877082 100644 --- a/util/memory/action.go +++ b/util/memory/action.go @@ -32,6 +32,9 @@ type ActionOnExceed interface { // SetLogHook binds a log hook which will be triggered and log an detailed // message for the out-of-memory sql. SetLogHook(hook func(uint64)) + // SetFallback sets a fallback action which will be triggered if itself has + // already been triggered. + SetFallback(a ActionOnExceed) } // LogOnExceed logs a warning only once when memory usage exceeds memory quota. @@ -62,6 +65,9 @@ func (a *LogOnExceed) Action(t *Tracker) { } } +// SetFallback sets a fallback action. +func (a *LogOnExceed) SetFallback(ActionOnExceed) {} + // PanicOnExceed panics when memory usage exceeds memory quota. type PanicOnExceed struct { mutex sync.Mutex // For synchronization. @@ -90,6 +96,9 @@ func (a *PanicOnExceed) Action(t *Tracker) { panic(PanicMemoryExceed + fmt.Sprintf("[conn_id=%d]", a.ConnID)) } +// SetFallback sets a fallback action. +func (a *PanicOnExceed) SetFallback(ActionOnExceed) {} + var ( errMemExceedThreshold = terror.ClassExecutor.New(codeMemExceedThreshold, mysql.MySQLErrName[mysql.ErrMemExceedThreshold]) ) diff --git a/util/memory/tracker.go b/util/memory/tracker.go index 9ffda11790391..a28116c85feec 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -42,24 +42,28 @@ type Tracker struct { sync.Mutex children []*Tracker // The children memory trackers } + actionMu struct { + sync.Mutex + actionOnExceed ActionOnExceed + } - label fmt.Stringer // Label of this "Tracker". - bytesConsumed int64 // Consumed bytes. - bytesLimit int64 // Negative value means no limit. - maxConsumed int64 // max number of bytes consumed during execution. - actionOnExceed ActionOnExceed - parent *Tracker // The parent memory tracker. + label fmt.Stringer // Label of this "Tracker". + bytesConsumed int64 // Consumed bytes. + bytesLimit int64 // bytesLimit <= 0 means no limit. + maxConsumed int64 // max number of bytes consumed during execution. + parent *Tracker // The parent memory tracker. } // NewTracker creates a memory tracker. // 1. "label" is the label used in the usage string. // 2. "bytesLimit <= 0" means no limit. func NewTracker(label fmt.Stringer, bytesLimit int64) *Tracker { - return &Tracker{ - label: label, - bytesLimit: bytesLimit, - actionOnExceed: &LogOnExceed{}, + t := &Tracker{ + label: label, + bytesLimit: bytesLimit, } + t.actionMu.actionOnExceed = &LogOnExceed{} + return t } // CheckBytesLimit check whether the bytes limit of the tracker is equal to a value. @@ -76,7 +80,18 @@ func (t *Tracker) SetBytesLimit(bytesLimit int64) { // SetActionOnExceed sets the action when memory usage exceeds bytesLimit. func (t *Tracker) SetActionOnExceed(a ActionOnExceed) { - t.actionOnExceed = a + t.actionMu.Lock() + t.actionMu.actionOnExceed = a + t.actionMu.Unlock() +} + +// FallbackOldAndSetNewAction sets the action when memory usage exceeds bytesLimit +// and set the original action as its fallback. +func (t *Tracker) FallbackOldAndSetNewAction(a ActionOnExceed) { + t.actionMu.Lock() + defer t.actionMu.Unlock() + a.SetFallback(t.actionMu.actionOnExceed) + t.actionMu.actionOnExceed = a } // SetLabel sets the label of a Tracker. @@ -151,12 +166,10 @@ func (t *Tracker) ReplaceChild(oldChild, newChild *Tracker) { // which means this is a memory release operation. When memory usage of a tracker // exceeds its bytesLimit, the tracker calls its action, so does each of its ancestors. func (t *Tracker) Consume(bytes int64) { + var rootExceed *Tracker for tracker := t; tracker != nil; tracker = tracker.parent { if atomic.AddInt64(&tracker.bytesConsumed, bytes) >= tracker.bytesLimit && tracker.bytesLimit > 0 { - // TODO(fengliyuan): try to find a way to avoid logging at each tracker in chain. - if tracker.actionOnExceed != nil { - tracker.actionOnExceed.Action(tracker) - } + rootExceed = tracker } for { @@ -168,6 +181,13 @@ func (t *Tracker) Consume(bytes int64) { break } } + if rootExceed != nil { + rootExceed.actionMu.Lock() + defer rootExceed.actionMu.Unlock() + if rootExceed.actionMu.actionOnExceed != nil { + rootExceed.actionMu.actionOnExceed.Action(rootExceed) + } + } } // BytesConsumed returns the consumed memory usage value in bytes. diff --git a/util/memory/tracker_test.go b/util/memory/tracker_test.go index a40d6d481bbc8..62cc7271f774d 100644 --- a/util/memory/tracker_test.go +++ b/util/memory/tracker_test.go @@ -96,19 +96,42 @@ func (s *testSuite) TestOOMAction(c *C) { c.Assert(action.called, IsFalse) tracker.Consume(10000) c.Assert(action.called, IsTrue) + + // test fallback + action1 := &mockAction{} + action2 := &mockAction{} + tracker.SetActionOnExceed(action1) + tracker.FallbackOldAndSetNewAction(action2) + c.Assert(action1.called, IsFalse) + c.Assert(action2.called, IsFalse) + tracker.Consume(10000) + c.Assert(action1.called, IsFalse) + c.Assert(action2.called, IsTrue) + tracker.Consume(10000) + c.Assert(action1.called, IsTrue) + c.Assert(action2.called, IsTrue) } type mockAction struct { - called bool + called bool + fallback ActionOnExceed } func (a *mockAction) SetLogHook(hook func(uint64)) { } func (a *mockAction) Action(t *Tracker) { + if a.called && a.fallback != nil { + a.fallback.Action(t) + return + } a.called = true } +func (a *mockAction) SetFallback(fallback ActionOnExceed) { + a.fallback = fallback +} + func (s *testSuite) TestAttachTo(c *C) { oldParent := NewTracker(stringutil.StringerStr("old parent"), -1) newParent := NewTracker(stringutil.StringerStr("new parent"), -1) diff --git a/util/mock/context.go b/util/mock/context.go index ba63400ddc5e2..c2a8cae86e811 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -28,7 +28,9 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/kvcache" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-binlog" ) @@ -266,6 +268,7 @@ func NewContext() *Context { sctx.sessionVars.InitChunkSize = 2 sctx.sessionVars.MaxChunkSize = 32 sctx.sessionVars.StmtCtx.TimeZone = time.UTC + sctx.sessionVars.StmtCtx.MemTracker = memory.NewTracker(stringutil.StringerStr("mock.NewContext"), -1) sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor() if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil { panic(err)