Skip to content

Commit

Permalink
expression: replace mock.Context with StaticExprContext in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao committed Apr 29, 2024
1 parent 0f0418f commit a815865
Show file tree
Hide file tree
Showing 46 changed files with 1,983 additions and 1,884 deletions.
4 changes: 3 additions & 1 deletion pkg/expression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ go_test(
"//pkg/errctx",
"//pkg/errno",
"//pkg/expression/context",
"//pkg/expression/contextopt",
"//pkg/expression/contextstatic",
"//pkg/infoschema/context",
"//pkg/kv",
"//pkg/parser",
"//pkg/parser/ast",
Expand All @@ -225,7 +228,6 @@ go_test(
"//pkg/util/chunk",
"//pkg/util/codec",
"//pkg/util/collate",
"//pkg/util/context",
"//pkg/util/hack",
"//pkg/util/mathutil",
"//pkg/util/mock",
Expand Down
174 changes: 108 additions & 66 deletions pkg/expression/bench_test.go

Large diffs are not rendered by default.

47 changes: 23 additions & 24 deletions pkg/expression/builtin_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestSetFlenDecimal4RealOrDecimal(t *testing.T) {
}

func TestArithmeticPlus(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
// case: 1
args := []any{int64(12), int64(1)}

Expand All @@ -103,7 +103,7 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, intSig)

intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
intResult, isNull, err := intSig.evalInt(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, int64(13), intResult)
Expand All @@ -118,7 +118,7 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err := realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, 1.00001, realResult)
Expand All @@ -133,7 +133,7 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err = realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
Expand All @@ -148,7 +148,7 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err = realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
Expand All @@ -165,7 +165,7 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, intSig)

intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
intResult, _, err = intSig.evalInt(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.Equal(t, int64(9007199254740993), intResult)

Expand All @@ -182,13 +182,13 @@ func TestArithmeticPlus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, intSig)

intResult, _, err = intSig.evalInt(ctx, chunk.Row{})
intResult, _, err = intSig.evalInt(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.Equal(t, int64(4), intResult)
}

func TestArithmeticMinus(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
// case: 1
args := []any{int64(12), int64(1)}

Expand All @@ -199,7 +199,7 @@ func TestArithmeticMinus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, intSig)

intResult, isNull, err := intSig.evalInt(ctx, chunk.Row{})
intResult, isNull, err := intSig.evalInt(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, int64(11), intResult)
Expand All @@ -214,7 +214,7 @@ func TestArithmeticMinus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err := realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err := realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.False(t, isNull)
require.Equal(t, 1.02001, realResult)
Expand All @@ -229,7 +229,7 @@ func TestArithmeticMinus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err = realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
Expand All @@ -244,7 +244,7 @@ func TestArithmeticMinus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err = realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
Expand All @@ -259,14 +259,14 @@ func TestArithmeticMinus(t *testing.T) {
require.True(t, ok)
require.NotNil(t, realSig)

realResult, isNull, err = realSig.evalReal(ctx, chunk.Row{})
realResult, isNull, err = realSig.evalReal(ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
require.True(t, isNull)
require.Equal(t, float64(0), realResult)
}

func TestArithmeticMultiply(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
testCases := []struct {
args []any
expect []any
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestArithmeticMultiply(t *testing.T) {
sig, err := funcs[ast.Mul].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
val, err := evalBuiltinFunc(sig, ctx.GetEvalCtx(), chunk.Row{})
if tc.expect[1] == nil {
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
Expand All @@ -322,8 +322,7 @@ func TestArithmeticMultiply(t *testing.T) {
}

func TestArithmeticDivide(t *testing.T) {
ctx := createContext(t)

ctx := mockStmtTruncateAsWarningExprCtx(t)
testCases := []struct {
args []any
expect any
Expand Down Expand Up @@ -384,14 +383,14 @@ func TestArithmeticDivide(t *testing.T) {
case *builtinArithmeticIntDivideDecimalSig:
require.Equal(t, tipb.ScalarFuncSig_IntDivideDecimal, sig.PbCode())
}
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
val, err := evalBuiltinFunc(sig, ctx.GetEvalCtx(), chunk.Row{})
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect), val)
}
}

func TestArithmeticIntDivide(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
testCases := []struct {
args []any
expect []any
Expand Down Expand Up @@ -494,7 +493,7 @@ func TestArithmeticIntDivide(t *testing.T) {
sig, err := funcs[ast.IntDiv].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
val, err := evalBuiltinFunc(sig, ctx.GetEvalCtx(), chunk.Row{})
if tc.expect[1] == nil {
require.NoError(t, err)
testutil.DatumEqual(t, types.NewDatum(tc.expect[0]), val)
Expand All @@ -506,7 +505,7 @@ func TestArithmeticIntDivide(t *testing.T) {
}

func TestArithmeticMod(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
testCases := []struct {
args []any
expect any
Expand Down Expand Up @@ -637,7 +636,7 @@ func TestArithmeticMod(t *testing.T) {
sig, err := funcs[ast.Mod].getFunction(ctx, datumsToConstants(types.MakeDatums(tc.args...)))
require.NoError(t, err)
require.NotNil(t, sig)
val, err := evalBuiltinFunc(sig, ctx, chunk.Row{})
val, err := evalBuiltinFunc(sig, ctx.GetEvalCtx(), chunk.Row{})
switch sig.(type) {
case *builtinArithmeticModRealSig:
require.Equal(t, tipb.ScalarFuncSig_ModReal, sig.PbCode())
Expand All @@ -658,7 +657,7 @@ func TestArithmeticMod(t *testing.T) {
}

func TestDecimalErrOverflow(t *testing.T) {
ctx := createContext(t)
ctx := mockStmtTruncateAsWarningExprCtx(t)
testCases := []struct {
args []float64
opd string
Expand Down Expand Up @@ -696,7 +695,7 @@ func TestDecimalErrOverflow(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, bf)
require.Equal(t, tc.sig, bf.PbCode())
_, err = evalBuiltinFunc(bf, ctx, chunk.Row{})
_, err = evalBuiltinFunc(bf, ctx.GetEvalCtx(), chunk.Row{})
require.EqualError(t, err, tc.errStr)
}
}
5 changes: 2 additions & 3 deletions pkg/expression/builtin_arithmetic_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -197,7 +196,7 @@ func TestVectorizedBuiltinArithmeticFunc(t *testing.T) {
}

func TestVectorizedDecimalErrOverflow(t *testing.T) {
ctx := mock.NewContext()
ctx := mockStmtExprCtx(t)
testCases := []struct {
args []float64
funcName string
Expand Down Expand Up @@ -231,7 +230,7 @@ func TestVectorizedDecimalErrOverflow(t *testing.T) {
baseFunc, err := funcs[tt.funcName].getFunction(ctx, cols)
require.NoError(t, err)
result := chunk.NewColumn(eType2FieldType(types.ETDecimal), 1)
err = vecEvalType(ctx, baseFunc, types.ETDecimal, input, result)
err = vecEvalType(ctx.GetEvalCtx(), baseFunc, types.ETDecimal, input, result)
require.EqualError(t, err, tt.errStr)
}
}
Expand Down
11 changes: 6 additions & 5 deletions pkg/expression/builtin_cast_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/mock"
)

func genCastIntAsInt(ctx BuildContext) (*builtinCastIntAsIntSig, *chunk.Chunk, *chunk.Column) {
Expand All @@ -41,25 +40,27 @@ func genCastIntAsInt(ctx BuildContext) (*builtinCastIntAsIntSig, *chunk.Chunk, *
}

func BenchmarkCastIntAsIntRow(b *testing.B) {
ctx := mock.NewContext()
ctx := mockStmtExprCtx(b)
evalCtx := ctx.GetEvalCtx()
cast, input, _ := genCastIntAsInt(ctx)
it := chunk.NewIterator4Chunk(input)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for row := it.Begin(); row != it.End(); row = it.Next() {
if _, _, err := cast.evalInt(ctx, row); err != nil {
if _, _, err := cast.evalInt(evalCtx, row); err != nil {
b.Fatal(err)
}
}
}
}

func BenchmarkCastIntAsIntVec(b *testing.B) {
ctx := mock.NewContext()
ctx := mockStmtExprCtx(b)
evalCtx := ctx.GetEvalCtx()
cast, input, result := genCastIntAsInt(ctx)
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := cast.vecEvalInt(ctx, input, result); err != nil {
if err := cast.vecEvalInt(evalCtx, input, result); err != nil {
b.Fatal(err)
}
}
Expand Down

0 comments on commit a815865

Please sign in to comment.