diff --git a/lib/logsql/select.go b/lib/logsql/select.go index caec74b..eb895f4 100644 --- a/lib/logsql/select.go +++ b/lib/logsql/select.go @@ -2479,11 +2479,39 @@ func (v *selectTranslatorVisitor) translateExpr(expr ast.Expr) (string, error) { } func (v *selectTranslatorVisitor) translateComparison(left, right ast.Expr, cmp comparisonKind) (string, error) { - field, lit, flipped, err := v.extractFieldAndLiteral(left, right) + leftField, leftIsField, err := v.fieldNameFromExpr(left) if err != nil { return "", err } + rightField, rightIsField, err := v.fieldNameFromExpr(right) + if err != nil { + return "", err + } + + switch { + case leftIsField && rightIsField: + return translateFieldComparison(leftField, rightField, cmp) + case leftIsField: + lit, err := literalFromExpr(right) + if err != nil { + return "", err + } + return buildFieldLiteralComparison(leftField, lit, false, cmp) + case rightIsField: + lit, err := literalFromExpr(left) + if err != nil { + return "", err + } + return buildFieldLiteralComparison(rightField, lit, true, cmp) + default: + return "", &TranslationError{ + Code: http.StatusBadRequest, + Message: "translator: comparison requires identifier and literal", + } + } +} +func buildFieldLiteralComparison(field string, lit literalValue, flipped bool, cmp comparisonKind) (string, error) { switch cmp { case comparisonEqual: clause := field + ":" + lit.format() @@ -2519,6 +2547,31 @@ func (v *selectTranslatorVisitor) translateComparison(left, right ast.Expr, cmp } } +func translateFieldComparison(leftField, rightField string, cmp comparisonKind) (string, error) { + switch cmp { + case comparisonEqual: + return fmt.Sprintf("%s:eq_field(%s)", leftField, rightField), nil + case comparisonNotEqual: + clause := fmt.Sprintf("%s:eq_field(%s)", leftField, rightField) + return "-" + clause, nil + case comparisonLess: + return fmt.Sprintf("%s:lt_field(%s)", leftField, rightField), nil + case comparisonLessEqual: + return fmt.Sprintf("%s:le_field(%s)", leftField, rightField), nil + case comparisonGreater: + clause := fmt.Sprintf("%s:le_field(%s)", leftField, rightField) + return "-" + clause, nil + case comparisonGreaterEqual: + clause := fmt.Sprintf("%s:lt_field(%s)", leftField, rightField) + return "-" + clause, nil + default: + return "", &TranslationError{ + Code: http.StatusBadRequest, + Message: "translator: unsupported comparison kind", + } + } +} + func (v *selectTranslatorVisitor) translateBetweenExpr(expr *ast.BetweenExpr) (string, error) { if expr == nil { return "", &TranslationError{ @@ -2749,31 +2802,6 @@ func (v *selectTranslatorVisitor) rawFieldName(ident *ast.Identifier) (string, e return field, nil } -func (v *selectTranslatorVisitor) extractFieldAndLiteral(left, right ast.Expr) (string, literalValue, bool, error) { - if name, ok, err := v.fieldNameFromExpr(left); err != nil { - return "", literalValue{}, false, err - } else if ok { - lit, err := literalFromExpr(right) - if err != nil { - return "", literalValue{}, false, err - } - return name, lit, false, nil - } - if name, ok, err := v.fieldNameFromExpr(right); err != nil { - return "", literalValue{}, false, err - } else if ok { - lit, err := literalFromExpr(left) - if err != nil { - return "", literalValue{}, false, err - } - return name, lit, true, nil - } - return "", literalValue{}, false, &TranslationError{ - Code: http.StatusBadRequest, - Message: "translator: comparison requires identifier and literal", - } -} - func (v *selectTranslatorVisitor) fieldNameFromExpr(expr ast.Expr) (string, bool, error) { switch e := expr.(type) { case *ast.Identifier: diff --git a/lib/logsql/select_test.go b/lib/logsql/select_test.go index 862e593..f36bc6c 100644 --- a/lib/logsql/select_test.go +++ b/lib/logsql/select_test.go @@ -201,6 +201,41 @@ func TestToLogsQLSuccess(t *testing.T) { sql: "SELECT * FROM logs WHERE message LIKE '_foo'", expected: "message:~\"^.foo$\"", }, + { + name: "compare fields equality", + sql: "SELECT * FROM logs WHERE user_id = customer_id", + expected: "user_id:eq_field(customer_id)", + }, + { + name: "compare fields inequality", + sql: "SELECT * FROM logs WHERE duration != max_duration", + expected: "-duration:eq_field(max_duration)", + }, + { + name: "compare fields less than", + sql: "SELECT * FROM logs WHERE duration < max_duration", + expected: "duration:lt_field(max_duration)", + }, + { + name: "compare fields less or equal", + sql: "SELECT * FROM logs WHERE duration <= max_duration", + expected: "duration:le_field(max_duration)", + }, + { + name: "compare fields greater than", + sql: "SELECT * FROM logs WHERE duration > max_duration", + expected: "-duration:le_field(max_duration)", + }, + { + name: "compare fields greater or equal", + sql: "SELECT * FROM logs WHERE duration >= max_duration", + expected: "-duration:lt_field(max_duration)", + }, + { + name: "compare function fields equality", + sql: "SELECT * FROM logs WHERE LOWER(user) = LOWER(customer)", + expected: "* | format \"\" as __filter_expr_1 | format \"\" as __filter_expr_2 | filter __filter_expr_1:eq_field(__filter_expr_2) | delete __filter_expr_1, __filter_expr_2", + }, { name: "arithmetic projection", sql: "SELECT (duration_ms / 1000) AS duration_s FROM logs",