diff --git a/EvaluableExpression_sql.go b/EvaluableExpression_sql.go index 7e0ad1c..2f99011 100644 --- a/EvaluableExpression_sql.go +++ b/EvaluableExpression_sql.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "regexp" + "strings" "time" ) @@ -157,7 +158,13 @@ func (this EvaluableExpression) findNextSQLString(stream *tokenStream, transacti ret = ")" case SEPARATOR: ret = "," - + case MAP: + out := strings.Builder{} + for k, _ := range token.Value.(map[interface{}]struct{}) { + out.WriteString(fmt.Sprintf("%v", k)) + out.WriteString(" , ") + } + return "( "+strings.TrimRight(out.String(), " , ")+" )", nil default: errorMsg := fmt.Sprintf("Unrecognized query token '%s' of kind '%s'", token.Value, token.Kind) return "", errors.New(errorMsg) diff --git a/TokenKind.go b/TokenKind.go index 7c9516d..0354a63 100644 --- a/TokenKind.go +++ b/TokenKind.go @@ -27,6 +27,8 @@ const ( CLAUSE_CLOSE TERNARY + + MAP ) /* @@ -47,6 +49,8 @@ func (kind TokenKind) String() string { return "STRING" case PATTERN: return "PATTERN" + case MAP: + return "MAP" case TIME: return "TIME" case VARIABLE: diff --git a/benchmarks_test.go b/benchmarks_test.go index 23eaf67..09bb02b 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -249,3 +249,20 @@ func BenchmarkNestedAccessors(bench *testing.B) { expression.Evaluate(fooFailureParameters) } } + +func BenchmarkIn(bench *testing.B) { + + expressionString := "a in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)" + expression, _ := NewEvaluableExpression(expressionString) + + bench.ResetTimer() + var val interface{} + for i := 0; i < bench.N; i++ { + val, _ = expression.Evaluate(map[string]interface{}{ + "a": 4, + }) + } + if val != true { + bench.Error("expected true") + } +} diff --git a/evaluationStage.go b/evaluationStage.go index 11ea587..b9ede31 100644 --- a/evaluationStage.go +++ b/evaluationStage.go @@ -419,11 +419,16 @@ func separatorStage(left interface{}, right interface{}, parameters Parameters) } func inStage(left interface{}, right interface{}, parameters Parameters) (interface{}, error) { - - for _, value := range right.([]interface{}) { - if left == value { - return true, nil + switch t := right.(type) { + case []interface{}: + for _, value := range t { + if left == value { + return true, nil + } } + case map[interface{}]struct{}: + _, hit := t[left] + return hit, nil } return false, nil } @@ -496,10 +501,12 @@ func comparatorTypeCheck(left interface{}, right interface{}) bool { return false } -func isArray(value interface{}) bool { +func isArrayOrMap(value interface{}) bool { switch value.(type) { case []interface{}: return true + case map[interface{}]struct{}: + return true } return false } diff --git a/parsing.go b/parsing.go index 40c7ed2..9d80214 100644 --- a/parsing.go +++ b/parsing.go @@ -350,40 +350,98 @@ func readUntilFalse(stream *lexerStream, includeWhitespace bool, breakWhitespace */ func optimizeTokens(tokens []ExpressionToken) ([]ExpressionToken, error) { - var token ExpressionToken var symbol OperatorSymbol var err error - var index int - for index, token = range tokens { + for index := 0; index < len(tokens); index++ { + token := tokens[index] - // if we find a regex operator, and the right-hand value is a constant, precompile and replace with a pattern. if token.Kind != COMPARATOR { continue } symbol = comparatorSymbols[token.Value.(string)] - if symbol != REQ && symbol != NREQ { - continue - } + switch symbol { + case REQ, NREQ: // if we find a regex operator, and the right-hand value is a constant, precompile and replace with a pattern. + nextToken := tokens[index+1] + if nextToken.Kind == STRING { - index++ - token = tokens[index] - if token.Kind == STRING { + nextToken.Kind = PATTERN + nextToken.Value, err = regexp.Compile(nextToken.Value.(string)) - token.Kind = PATTERN - token.Value, err = regexp.Compile(token.Value.(string)) + if err != nil { + return tokens, err + } + + tokens[index+1] = nextToken + } + case IN: + nextToken := tokens[index+1] + if nextToken.Kind != CLAUSE { + continue + } + + isComp, endIndex, err := clauseIsComparable(tokens, index+1) if err != nil { - return tokens, err + return nil, err } + if !isComp { + break // switch + } + + mp := clauseToMap(tokens, index) + nextToken.Kind = MAP + nextToken.Value = mp + tokens[index+1] = nextToken - tokens[index] = token + // remove all tokens that have been condensed into map + newTokens := make([]ExpressionToken, 0, len(tokens)-(endIndex-index)) + newTokens = append(newTokens, tokens[:index+2]...) + newTokens = append(newTokens, tokens[endIndex+1:]...) + tokens = newTokens } } + return tokens, nil } +func clauseIsComparable(tokens []ExpressionToken, index int) (bool, int, error) { + if tokens[index].Kind != CLAUSE { + return false, 0, fmt.Errorf("token at index %d is %s, expected %s", index, tokens[index].Kind, CLAUSE) + } +loop: + for { + index++ + token := tokens[index] + switch token.Kind { + case CLAUSE_CLOSE: + break loop + case SEPARATOR, STRING, NUMERIC: + continue + default: + return false, index, nil + } + } + return true, index, nil +} + +func clauseToMap(tokens []ExpressionToken, index int) map[interface{}]struct{} { + mp := make(map[interface{}]struct{}) +loop: + for { + index++ + token := tokens[index] + switch token.Kind { + case CLAUSE_CLOSE: + break loop + case STRING, NUMERIC: + mp[token.Value] = struct{}{} + } + } + return mp +} + /* Checks the balance of tokens which have multiple parts, such as parenthesis. */ diff --git a/parsing_test.go b/parsing_test.go index d57b809..433e77a 100644 --- a/parsing_test.go +++ b/parsing_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "reflect" + "sort" "testing" "time" "unicode" @@ -878,9 +879,8 @@ func TestComparatorParsing(test *testing.T) { }, }, TokenParsingTest{ - - Name: "Array membership lowercase", - Input: "'foo' in ('foo', 'bar')", + Name: "Array membership complex entries", // this "in" lookup is not optimized via MAP + Input: "'foo' in ('foo', 1, 1 + 2)", Expected: []ExpressionToken{ ExpressionToken{ Kind: STRING, @@ -901,8 +901,23 @@ func TestComparatorParsing(test *testing.T) { Kind: SEPARATOR, }, ExpressionToken{ - Kind: STRING, - Value: "bar", + Kind: NUMERIC, + Value: 1.0, + }, + ExpressionToken{ + Kind: SEPARATOR, + }, + ExpressionToken{ + Kind: NUMERIC, + Value: 1.0, + }, + ExpressionToken{ + Kind: MODIFIER, + Value: "+", + }, + ExpressionToken{ + Kind: NUMERIC, + Value: 2.0, }, ExpressionToken{ Kind: CLAUSE_CLOSE, @@ -911,8 +926,8 @@ func TestComparatorParsing(test *testing.T) { }, TokenParsingTest{ - Name: "Array membership uppercase", - Input: "'foo' IN ('foo', 'bar')", + Name: "Array membership lowercase", // "in" lookup optimized via MAP + Input: "'foo' in ('foo', 'bar')", Expected: []ExpressionToken{ ExpressionToken{ Kind: STRING, @@ -923,21 +938,25 @@ func TestComparatorParsing(test *testing.T) { Value: "in", }, ExpressionToken{ - Kind: CLAUSE, + Kind: MAP, }, + }, + }, + TokenParsingTest{ + + Name: "Array membership uppercase", // "in" lookup optimized via MAP + Input: "'foo' IN ('foo', 'bar')", + Expected: []ExpressionToken{ ExpressionToken{ Kind: STRING, Value: "foo", }, ExpressionToken{ - Kind: SEPARATOR, - }, - ExpressionToken{ - Kind: STRING, - Value: "bar", + Kind: COMPARATOR, + Value: "in", }, ExpressionToken{ - Kind: CLAUSE_CLOSE, + Kind: MAP, }, }, }, @@ -1668,3 +1687,83 @@ func runTokenParsingTest(tokenParsingTests []TokenParsingTest, test *testing.T) func noop(arguments ...interface{}) (interface{}, error) { return nil, nil } + +func TestClauseIsComparable(t *testing.T) { + cases := []struct{ + expression string + expected bool + index int // zero based + endIndex int + } { + { + expression: "a in (1, 2, 3) && b", + expected: true, + index: 2, + endIndex: 8, + }, + { + expression: "a in (1, 'a') && b", + expected: true, + index: 2, + endIndex: 6, + }, + { + expression: "a in (1, b, 3) && b", + expected: false, + index: 2, + }, + { + expression: "a in (1, 1+1, 3) && b", + expected: false, + index: 2, + }, + } + + for _, c := range cases { + tokens, err := parseTokens(c.expression, nil) + if err != nil { + t.Fatal(err) + } + isComparable, endIndex, err := clauseIsComparable(tokens, c.index) + if err != nil { + t.Fatal(err) + } + if isComparable != c.expected { + t.Errorf("unexpected clause comparable result") + } + if !isComparable { + continue + } + if endIndex != c.endIndex { + t.Errorf("unexpected end index, expected %d, got %d", c.endIndex, endIndex) + } + } +} + +func TestClauseToMap(t *testing.T) { + tokens, err := parseTokens("a in (1, 2, 3) && b", nil) + if err != nil { + t.Fatal(err) + } + mp := clauseToMap(tokens, 2) + + keys := mapKeys(mp) + expected := []interface{}{1.0, 2.0, 3.0} + + if !reflect.DeepEqual(keys, expected) { + t.Errorf("unexpected clause map, expected %v, got %v", expected, mp) + } +} + +func mapKeys(mp map[interface{}]struct{}) []interface{} { + var keys []interface{} + for k, _ := range mp { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + a := keys[i] + b := keys[j] + return a.(float64) < b.(float64) + }) + return keys +} diff --git a/stagePlanner.go b/stagePlanner.go index d71ed12..ec37614 100644 --- a/stagePlanner.go +++ b/stagePlanner.go @@ -421,6 +421,8 @@ func planValue(stream *tokenStream) (*evaluationStage, error) { fallthrough case PATTERN: fallthrough + case MAP: + fallthrough case BOOLEAN: symbol = LITERAL operator = makeLiteralStage(token.Value) @@ -486,7 +488,7 @@ func findTypeChecks(symbol OperatorSymbol) typeChecks { } case IN: return typeChecks{ - right: isArray, + right: isArrayOrMap, } case BITWISE_LSHIFT: fallthrough