diff --git a/lib/logsql/select.go b/lib/logsql/select.go index 0895ef9..8fda47e 100644 --- a/lib/logsql/select.go +++ b/lib/logsql/select.go @@ -5,6 +5,7 @@ import ( "maps" "net/http" "regexp" + "sort" "strconv" "strings" @@ -44,6 +45,8 @@ var ( safeFormatFieldLiteral = regexp.MustCompile(`^[A-Za-z0-9_.-]+$`) ) +const constantSelectBasePipeline = "* | limit 1 | delete *" + type selectTranslatorVisitor struct { result string err error @@ -64,6 +67,11 @@ type selectTranslatorVisitor struct { filterOrder []string filterDelete []string filterDeleteSet map[string]struct{} + constantFields map[string]string + constantFieldCount int + aggTempDeletes map[string]string + aggPreserve map[string]struct{} + constantBase bool } type tableBinding struct { @@ -185,6 +193,11 @@ func (v *selectTranslatorVisitor) translateSimpleSelect(stmt *ast.SelectStatemen v.filterOrder = nil v.filterDelete = nil v.filterDeleteSet = nil + v.constantFields = nil + v.constantFieldCount = 0 + v.aggTempDeletes = nil + v.aggPreserve = nil + v.constantBase = false joinPipes, err := v.processFrom(stmt.From) if err != nil { @@ -246,7 +259,7 @@ func (v *selectTranslatorVisitor) translateSimpleSelect(stmt *ast.SelectStatemen } pipes = append(pipes, joinPipes...) - statsPipes, aggregated, err := v.buildStatsPipe(stmt) + statsPipes, aggregated, err := v.buildStatsPipe(stmt, stmt.Having) if err != nil { return "", err } @@ -266,6 +279,41 @@ func (v *selectTranslatorVisitor) translateSimpleSelect(stmt *ast.SelectStatemen return "", err } pipes = append(pipes, "filter "+havingStr) + if len(v.aggTempDeletes) > 0 { + if len(stmt.OrderBy) > 0 { + for _, item := range stmt.OrderBy { + fn, ok := item.Expr.(*ast.FuncCall) + if !ok { + continue + } + if !isAggregateFunction(fn) { + continue + } + key, err := v.aggregateKeyFromFunc(fn) + if err != nil { + return "", err + } + v.preserveAggregate(key) + } + } + keys := make([]string, 0, len(v.aggTempDeletes)) + for key := range v.aggTempDeletes { + if v.aggPreserve != nil { + if _, ok := v.aggPreserve[key]; ok { + continue + } + } + keys = append(keys, key) + } + if len(keys) > 0 { + sort.Strings(keys) + deleteVals := make([]string, 0, len(keys)) + for _, key := range keys { + deleteVals = append(deleteVals, v.aggTempDeletes[key]) + } + pipes = append(pipes, "delete "+strings.Join(deleteVals, ", ")) + } + } } projectionPipes, projectionFields, err := v.buildProjectionPipes(stmt.Columns, aggregated) @@ -389,11 +437,14 @@ func (v *selectTranslatorVisitor) buildDistinctPipe(fields []string, aggregated func (v *selectTranslatorVisitor) processFrom(from ast.TableExpr) ([]string, error) { if from == nil { - return nil, &TranslationError{ - Code: http.StatusBadRequest, - Message: "translator: FROM clause is required", - } + v.baseAlias = "" + v.baseUsesPipeline = true + v.basePipeline = constantSelectBasePipeline + v.baseFilter = "" + v.constantBase = true + return nil, nil } + v.constantBase = false switch t := from.(type) { case *ast.TableName: @@ -966,13 +1017,35 @@ func (v *selectTranslatorVisitor) ensureAliases(expr ast.Expr, allowed map[strin return nil } -func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement) ([]string, bool, error) { +func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement, having ast.Expr) ([]string, bool, error) { hasGroup := len(stmt.GroupBy) > 0 aggregates := make([]aggItem, 0) groupFields := make([]string, 0) groupLookup := make(map[string]struct{}) preGroupPipes := make([]string, 0) aliasSources := v.collectGroupAliases(stmt.Columns) + aggIndex := make(map[string]int) + + addAggregate := func(item aggItem) { + if idx, exists := aggIndex[item.key]; exists { + if len(item.prePipes) > 0 { + existing := aggregates[idx] + existing.prePipes = append(existing.prePipes, item.prePipes...) + if item.selected { + existing.selected = true + } + aggregates[idx] = existing + } + if item.selected && !aggregates[idx].selected { + existing := aggregates[idx] + existing.selected = true + aggregates[idx] = existing + } + return + } + aggIndex[item.key] = len(aggregates) + aggregates = append(aggregates, item) + } if hasGroup { v.groupExprAliases = make(map[string]string) @@ -1055,7 +1128,8 @@ func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement) ([]s if err != nil { return nil, false, err } - aggregates = append(aggregates, item) + item.selected = true + addAggregate(item) } else if hasGroup { if _, ok, err := v.lookupGroupExpr(expr); err != nil { return nil, false, err @@ -1103,13 +1177,26 @@ func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement) ([]s } } + if having != nil { + if err := v.collectAggregatesFromExpr(having, addAggregate); err != nil { + return nil, false, err + } + } + if len(aggregates) == 0 { if hasGroup { - return nil, false, &TranslationError{ - Code: http.StatusBadRequest, - Message: "translator: GROUP BY requires aggregate expressions", + v.aggResults = make(map[string]string) + v.aggTempDeletes = nil + var pipe string + if len(groupFields) > 0 { + pipe = fmt.Sprintf("uniq by (%s)", strings.Join(groupFields, ", ")) + } else { + pipe = "uniq" } + pipes := append(preGroupPipes, pipe) + return pipes, true, nil } + v.aggResults = nil return nil, false, nil } @@ -1121,6 +1208,12 @@ func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement) ([]s builder.WriteString(")") } + for _, agg := range aggregates { + if len(agg.prePipes) > 0 { + preGroupPipes = append(preGroupPipes, agg.prePipes...) + } + } + funcs := make([]string, 0, len(aggregates)) aggResults := make(map[string]string) for _, agg := range aggregates { @@ -1131,10 +1224,57 @@ func (v *selectTranslatorVisitor) buildStatsPipe(stmt *ast.SelectStatement) ([]s builder.WriteString(strings.Join(funcs, ", ")) v.aggResults = aggResults + deleteTargets := make(map[string]string) + for _, agg := range aggregates { + if agg.selected { + continue + } + deleteTargets[agg.key] = formatFieldName(agg.resultName) + } + if len(deleteTargets) > 0 { + v.aggTempDeletes = deleteTargets + } else { + v.aggTempDeletes = nil + } pipes := append(preGroupPipes, builder.String()) return pipes, true, nil } +func (v *selectTranslatorVisitor) collectAggregatesFromExpr(expr ast.Expr, add func(aggItem)) error { + if expr == nil { + return nil + } + funcs := make([]*ast.FuncCall, 0) + walkExpr(expr, func(e ast.Expr) { + if fn, ok := e.(*ast.FuncCall); ok { + if isAggregateFunction(fn) { + funcs = append(funcs, fn) + } + } + }) + for _, fn := range funcs { + if fn.Over != nil { + return &TranslationError{ + Code: http.StatusBadRequest, + Message: "translator: window functions are not supported in HAVING", + } + } + item, err := v.analyzeAggregate(fn, "") + if err != nil { + return err + } + add(item) + } + return nil +} + +func (v *selectTranslatorVisitor) preserveAggregate(key string) { + if v.aggPreserve == nil { + v.aggPreserve = make(map[string]struct{}) + } + v.aggPreserve[key] = struct{}{} +} + func (v *selectTranslatorVisitor) prepareGroupByField(expr ast.Expr, index int) (string, []string, error) { switch e := expr.(type) { case *ast.Identifier: @@ -1162,7 +1302,7 @@ func (v *selectTranslatorVisitor) prepareGroupByField(expr ast.Expr, index int) return "", nil, err } return aliasName, []string{mathPipe}, nil - case *ast.BinaryExpr, *ast.UnaryExpr, *ast.NumericLiteral: + case *ast.BinaryExpr, *ast.UnaryExpr, *ast.NumericLiteral, *ast.StringLiteral: alias := fmt.Sprintf("group_%d", index+1) mathPipe, aliasName, err := v.translateMathProjection(expr, alias) if err != nil { @@ -1251,6 +1391,8 @@ type aggItem struct { key string statsCall string resultName string + prePipes []string + selected bool } func (v *selectTranslatorVisitor) analyzeAggregate(fn *ast.FuncCall, alias string) (aggItem, error) { @@ -1275,24 +1417,41 @@ func (v *selectTranslatorVisitor) analyzeAggregate(fn *ast.FuncCall, alias strin } name := strings.ToUpper(fn.Name.Parts[len(fn.Name.Parts)-1]) - var arg string + var ( + keyArg string + callArg string + prePipes []string + ) switch name { case "COUNT": if len(fn.Args) == 0 { - arg = "*" + keyArg = "*" + callArg = "*" } else if len(fn.Args) == 1 { if _, ok := fn.Args[0].(*ast.StarExpr); ok { - arg = "*" + keyArg = "*" + callArg = "*" } else if ident, ok := fn.Args[0].(*ast.Identifier); ok { field, err := v.normalizeIdentifier(ident) if err != nil { return aggItem{}, err } - arg = field + keyArg = field + callArg = field + } else if lit, ok := fn.Args[0].(*ast.NumericLiteral); ok { + keyArg = lit.Value + field, pipe, err := v.ensureConstantField(lit.Value) + if err != nil { + return aggItem{}, err + } + callArg = field + if pipe != "" { + prePipes = append(prePipes, pipe) + } } else { return aggItem{}, &TranslationError{ Code: http.StatusBadRequest, - Message: "translator: COUNT only supports identifiers or *", + Message: "translator: COUNT only supports identifiers, numeric literals, or *", } } } else { @@ -1308,18 +1467,30 @@ func (v *selectTranslatorVisitor) analyzeAggregate(fn *ast.FuncCall, alias strin Message: fmt.Sprintf("translator: %s expects single argument", strings.ToLower(name)), } } - ident, ok := fn.Args[0].(*ast.Identifier) - if !ok { + switch argExpr := fn.Args[0].(type) { + case *ast.Identifier: + field, err := v.normalizeIdentifier(argExpr) + if err != nil { + return aggItem{}, err + } + keyArg = field + callArg = field + case *ast.NumericLiteral: + keyArg = argExpr.Value + field, pipe, err := v.ensureConstantField(argExpr.Value) + if err != nil { + return aggItem{}, err + } + callArg = field + if pipe != "" { + prePipes = append(prePipes, pipe) + } + default: return aggItem{}, &TranslationError{ Code: http.StatusBadRequest, - Message: fmt.Sprintf("translator: %s only supports identifiers", strings.ToLower(name)), + Message: fmt.Sprintf("translator: %s only supports identifiers or numeric literals", strings.ToLower(name)), } } - field, err := v.normalizeIdentifier(ident) - if err != nil { - return aggItem{}, err - } - arg = field default: return aggItem{}, &TranslationError{ Code: http.StatusBadRequest, @@ -1327,15 +1498,15 @@ func (v *selectTranslatorVisitor) analyzeAggregate(fn *ast.FuncCall, alias strin } } - key := aggregateKey(name, arg) - fnCall := fmt.Sprintf("%s(%s)", strings.ToLower(name), formatAggregateArg(arg)) + key := aggregateKey(name, keyArg) + fnCall := fmt.Sprintf("%s(%s)", strings.ToLower(name), formatAggregateArg(callArg)) alias = strings.TrimSpace(alias) if alias == "" { - return aggItem{key: key, statsCall: fnCall, resultName: fnCall}, nil + return aggItem{key: key, statsCall: fnCall, resultName: fnCall, prePipes: prePipes}, nil } formattedAlias := formatFieldName(alias) call := fmt.Sprintf("%s %s", fnCall, formattedAlias) - return aggItem{key: key, statsCall: call, resultName: formattedAlias}, nil + return aggItem{key: key, statsCall: call, resultName: formattedAlias, prePipes: prePipes}, nil } func isAggregateFunction(fn *ast.FuncCall) bool { @@ -1531,6 +1702,26 @@ func sanitizeAliasFromField(field string) string { return value } +func (v *selectTranslatorVisitor) ensureConstantField(value string) (string, string, error) { + if strings.TrimSpace(value) == "" { + return "", "", &TranslationError{ + Code: http.StatusBadRequest, + Message: "translator: constant aggregate requires non-empty numeric literal", + } + } + if v.constantFields == nil { + v.constantFields = make(map[string]string) + } + if field, ok := v.constantFields[value]; ok { + return field, "", nil + } + v.constantFieldCount++ + field := fmt.Sprintf("__const_%d", v.constantFieldCount) + pipe := fmt.Sprintf("format %s as %s", value, field) + v.constantFields[value] = field + return field, pipe, nil +} + func escapeFormatPattern(pattern string) string { pattern = strings.ReplaceAll(pattern, "\\", "\\\\") pattern = strings.ReplaceAll(pattern, "\"", "\\\"") @@ -1812,8 +2003,9 @@ func (v *selectTranslatorVisitor) translateWindowFunction(fn *ast.FuncCall, alia } name := strings.ToUpper(fn.Name.Parts[len(fn.Name.Parts)-1]) var ( - statsCall string - aliasSource string + statsCall string + aliasSource string + constantPipe string ) switch name { case "SUM", "MIN", "MAX": @@ -1826,19 +2018,28 @@ func (v *selectTranslatorVisitor) translateWindowFunction(fn *ast.FuncCall, alia if err := v.ensureBaseAliasesOnly(fn.Args[0]); err != nil { return nil, "", err } - ident, ok := fn.Args[0].(*ast.Identifier) - if !ok { + switch arg := fn.Args[0].(type) { + case *ast.Identifier: + field, err := v.normalizeIdentifier(arg) + if err != nil { + return nil, "", err + } + statsCall = fmt.Sprintf("%s(%s)", strings.ToLower(name), field) + aliasSource = field + case *ast.NumericLiteral: + field, pipe, err := v.ensureConstantField(arg.Value) + if err != nil { + return nil, "", err + } + statsCall = fmt.Sprintf("%s(%s)", strings.ToLower(name), field) + aliasSource = arg.Value + constantPipe = pipe + default: return nil, "", &TranslationError{ Code: http.StatusBadRequest, - Message: fmt.Sprintf("translator: %s window function requires identifier argument", strings.ToLower(name)), + Message: fmt.Sprintf("translator: %s window function requires identifier or numeric literal argument", strings.ToLower(name)), } } - field, err := v.normalizeIdentifier(ident) - if err != nil { - return nil, "", err - } - statsCall = fmt.Sprintf("%s(%s)", strings.ToLower(name), field) - aliasSource = field case "COUNT": if len(fn.Args) == 0 { statsCall = "count()" @@ -1858,10 +2059,20 @@ func (v *selectTranslatorVisitor) translateWindowFunction(fn *ast.FuncCall, alia } statsCall = fmt.Sprintf("count(%s)", field) aliasSource = field + case *ast.NumericLiteral: + field, pipe, err := v.ensureConstantField(arg.Value) + if err != nil { + return nil, "", err + } + statsCall = fmt.Sprintf("count(%s)", field) + aliasSource = arg.Value + if pipe != "" { + constantPipe = pipe + } default: return nil, "", &TranslationError{ Code: http.StatusBadRequest, - Message: "translator: COUNT window function only supports identifiers or *", + Message: "translator: COUNT window function only supports identifiers, numeric literals, or *", } } } else { @@ -1918,6 +2129,9 @@ func (v *selectTranslatorVisitor) translateWindowFunction(fn *ast.FuncCall, alia } pipes = append(pipes, orderPipe) } + if constantPipe != "" { + pipes = append(pipes, constantPipe) + } statsPipe := fmt.Sprintf("running_stats%s %s as %s", partitionClause, statsCall, aliasName) pipes = append(pipes, statsPipe) return pipes, aliasName, nil @@ -2236,7 +2450,78 @@ func (v *selectTranslatorVisitor) buildProjectionPipes(columns []ast.SelectItem, } computedPipes = append(computedPipes, mathPipe) fields = append(fields, formatFieldName(aliasName)) - case *ast.BinaryExpr, *ast.UnaryExpr, *ast.NumericLiteral: + case *ast.NumericLiteral: + if aggregated { + groupField, ok, err := v.lookupGroupExpr(col.Expr) + if err != nil { + return nil, nil, err + } + if !ok { + return nil, nil, &TranslationError{ + Code: http.StatusBadRequest, + Message: fmt.Sprintf("translator: unsupported expression %T in aggregated select", expr), + } + } + finalName := groupField + if alias := strings.TrimSpace(col.Alias); alias != "" { + formattedAlias := formatFieldName(alias) + if formattedAlias != groupField { + renamePairs = append(renamePairs, fmt.Sprintf("%s as %s", groupField, formattedAlias)) + } + finalName = formattedAlias + } + fields = append(fields, finalName) + continue + } + aliasTrim := strings.TrimSpace(col.Alias) + if aliasTrim == "" && v.constantBase { + computedPipes = append(computedPipes, fmt.Sprintf("format %s", expr.Value)) + continue + } + aliasName, err := makeProjectionAlias(aliasTrim, "literal", expr.Value) + if err != nil { + return nil, nil, err + } + pipe := fmt.Sprintf("format %s as %s", expr.Value, formatFieldName(aliasName)) + computedPipes = append(computedPipes, pipe) + fields = append(fields, formatFieldName(aliasName)) + case *ast.StringLiteral: + if aggregated { + groupField, ok, err := v.lookupGroupExpr(col.Expr) + if err != nil { + return nil, nil, err + } + if !ok { + return nil, nil, &TranslationError{ + Code: http.StatusBadRequest, + Message: fmt.Sprintf("translator: unsupported expression %T in aggregated select", expr), + } + } + finalName := groupField + if alias := strings.TrimSpace(col.Alias); alias != "" { + formattedAlias := formatFieldName(alias) + if formattedAlias != groupField { + renamePairs = append(renamePairs, fmt.Sprintf("%s as %s", groupField, formattedAlias)) + } + finalName = formattedAlias + } + fields = append(fields, finalName) + continue + } + aliasTrim := strings.TrimSpace(col.Alias) + value := quoteString(expr.Value) + if aliasTrim == "" && v.constantBase { + computedPipes = append(computedPipes, fmt.Sprintf("format %s", value)) + continue + } + aliasName, err := makeProjectionAlias(aliasTrim, "literal", expr.Value) + if err != nil { + return nil, nil, err + } + pipe := fmt.Sprintf("format %s as %s", value, formatFieldName(aliasName)) + computedPipes = append(computedPipes, pipe) + fields = append(fields, formatFieldName(aliasName)) + case *ast.BinaryExpr, *ast.UnaryExpr: if aggregated { groupField, ok, err := v.lookupGroupExpr(col.Expr) if err != nil { @@ -2283,7 +2568,7 @@ func (v *selectTranslatorVisitor) buildProjectionPipes(columns []ast.SelectItem, if len(renamePairs) > 0 { pipes = append(pipes, "rename "+strings.Join(renamePairs, ", ")) } - if len(fields) > 0 && !aggregated { + if len(fields) > 0 && (!aggregated || len(v.aggResults) == 0) { pipes = append(pipes, "fields "+strings.Join(fields, ", ")) } return pipes, fields, nil @@ -2362,10 +2647,12 @@ func (v *selectTranslatorVisitor) aggregateKeyFromFunc(fn *ast.FuncCall) (string return "", err } arg = field + } else if lit, ok := fn.Args[0].(*ast.NumericLiteral); ok { + arg = lit.Value } else { return "", &TranslationError{ Code: http.StatusBadRequest, - Message: "translator: COUNT only supports identifiers or *", + Message: "translator: COUNT only supports identifiers, numeric literals, or *", } } } else { @@ -2381,18 +2668,21 @@ func (v *selectTranslatorVisitor) aggregateKeyFromFunc(fn *ast.FuncCall) (string Message: fmt.Sprintf("translator: %s expects single argument", strings.ToLower(name)), } } - ident, ok := fn.Args[0].(*ast.Identifier) - if !ok { + switch argExpr := fn.Args[0].(type) { + case *ast.Identifier: + field, err := v.normalizeIdentifier(argExpr) + if err != nil { + return "", err + } + arg = field + case *ast.NumericLiteral: + arg = argExpr.Value + default: return "", &TranslationError{ Code: http.StatusBadRequest, - Message: fmt.Sprintf("translator: %s only supports identifiers", strings.ToLower(name)), + Message: fmt.Sprintf("translator: %s only supports identifiers or numeric literals", strings.ToLower(name)), } } - field, err := v.normalizeIdentifier(ident) - if err != nil { - return "", err - } - arg = field default: return "", &TranslationError{ Code: http.StatusBadRequest, @@ -2510,7 +2800,7 @@ func (v *selectTranslatorVisitor) translateExpr(expr ast.Expr) (string, error) { return "", err } if name, ok := v.aggResults[key]; ok { - return name, nil + return formatFieldName(name), nil } } return "", &TranslationError{ @@ -2887,7 +3177,7 @@ func (v *selectTranslatorVisitor) fieldNameFromExpr(expr ast.Expr) (string, bool Message: "translator: unknown aggregate referenced", } } - return name, true, nil + return formatFieldName(name), true, nil } if groupField, ok, err := v.lookupGroupExpr(e); err != nil { return "", false, err diff --git a/lib/logsql/select_test.go b/lib/logsql/select_test.go index c0fb353..4aae2d0 100644 --- a/lib/logsql/select_test.go +++ b/lib/logsql/select_test.go @@ -126,6 +126,41 @@ func TestToLogsQLSuccess(t *testing.T) { sql: "SELECT * FROM logs OFFSET 3", expected: "* | offset 3", }, + { + name: "select literal without from", + sql: "SELECT 1", + expected: "* | limit 1 | delete * | format 1", + }, + { + name: "select literal with alias without from", + sql: "SELECT 1 AS one", + expected: "* | limit 1 | delete * | format 1 as one | fields one", + }, + { + name: "select literal with from", + sql: "SELECT 1 FROM logs", + expected: "* | format 1 as literal_1 | fields literal_1", + }, + { + name: "select string literal without from", + sql: "SELECT 'hello'", + expected: "* | limit 1 | delete * | format \"hello\"", + }, + { + name: "select string literal with alias without from", + sql: "SELECT 'hello' AS greeting", + expected: "* | limit 1 | delete * | format \"hello\" as greeting | fields greeting", + }, + { + name: "select string literal with from", + sql: "SELECT 'hello' FROM logs", + expected: "* | format \"hello\" as literal_hello | fields literal_hello", + }, + { + name: "select qualified star", + sql: "SELECT l.* FROM logs l", + expected: "*", + }, { name: "in list", sql: "SELECT * FROM logs WHERE service IN ('api', 'worker')", @@ -176,6 +211,16 @@ func TestToLogsQLSuccess(t *testing.T) { sql: "SELECT COUNT(*) FROM logs", expected: "* | stats count()", }, + { + name: "count numeric literal", + sql: "SELECT COUNT(1) FROM logs", + expected: "* | format 1 as __const_1 | stats count(__const_1)", + }, + { + name: "sum numeric literal", + sql: "SELECT SUM(1) FROM logs", + expected: "* | format 1 as __const_1 | stats sum(__const_1)", + }, { name: "trim function", sql: "SELECT TRIM(message) AS trimmed FROM logs", @@ -306,6 +351,16 @@ func TestToLogsQLSuccess(t *testing.T) { sql: "SELECT COUNT(*) OVER (ORDER BY _time) AS running_count FROM logs", expected: "* | sort by (_time) | running_stats count() as running_count | fields running_count", }, + { + name: "window count numeric literal", + sql: "SELECT COUNT(1) OVER (ORDER BY _time) AS running_count FROM logs", + expected: "* | sort by (_time) | format 1 as __const_1 | running_stats count(__const_1) as running_count | fields running_count", + }, + { + name: "window sum numeric literal", + sql: "SELECT SUM(1) OVER (ORDER BY _time) AS running_total FROM logs", + expected: "* | sort by (_time) | format 1 as __const_1 | running_stats sum(__const_1) as running_total | fields running_total", + }, { name: "ceil function", sql: "SELECT CEIL(duration_ms / 1000.0) AS duration FROM logs", @@ -343,6 +398,27 @@ SELECT * FROM logs WHERE level = 'warn'`, sql: "SELECT level, COUNT(*) AS total FROM logs GROUP BY level HAVING COUNT(*) > 10", expected: "* | stats by (level) count() total | filter total:>10", }, + { + name: "group by count numeric literal", + sql: "SELECT service, COUNT(1) AS total FROM logs GROUP BY service", + expected: "* | format 1 as __const_1 | stats by (service) count(__const_1) total", + }, + { + name: "group by sum numeric literal", + sql: "SELECT service, SUM(1) AS total FROM logs GROUP BY service", + expected: "* | format 1 as __const_1 | stats by (service) sum(__const_1) total", + }, + { + name: "group by without aggregates", + sql: "SELECT kubernetes.container_name FROM logs GROUP BY kubernetes.container_name", + expected: "* | uniq by (kubernetes.container_name) | fields kubernetes.container_name", + }, + { + name: "having aggregate constant", + sql: "SELECT SUM(1) AS \"cnt_slack_079B451E84304DF1AAA4188E26F02806_ok\" FROM logs HAVING COUNT(1) > 0", + expected: "* | format 1 as __const_1 | stats sum(__const_1) cnt_slack_079B451E84304DF1AAA4188E26F02806_ok, count(__const_1) | " + + "filter \"count(__const_1)\":>0 | delete \"count(__const_1)\"", + }, { name: "with simple cte", sql: `WITH recent_errors AS ( diff --git a/lib/sql/parser/parser.go b/lib/sql/parser/parser.go index 1a2ae00..d93423c 100644 --- a/lib/sql/parser/parser.go +++ b/lib/sql/parser/parser.go @@ -21,8 +21,9 @@ type Parser struct { l *lexer.Lexer errors []error - curToken token.Token - peekToken token.Token + curToken token.Token + peekToken token.Token + peekToken2 token.Token depth int // Current recursion depth } @@ -30,7 +31,9 @@ type Parser struct { // New returns a parser over the provided lexer. func New(l *lexer.Lexer) *Parser { p := &Parser{l: l, errors: make([]error, 0)} - p.nextToken() + p.curToken = token.Token{} + p.peekToken = p.l.NextToken() + p.peekToken2 = p.l.NextToken() p.nextToken() return p } @@ -47,11 +50,13 @@ func (p *Parser) addError(pos token.Position, format string, args ...interface{} func (p *Parser) nextToken() { p.curToken = p.peekToken - p.peekToken = p.l.NextToken() + p.peekToken = p.peekToken2 + p.peekToken2 = p.l.NextToken() } -func (p *Parser) curTokenIs(t token.Type) bool { return p.curToken.Type == t } -func (p *Parser) peekTokenIs(t token.Type) bool { return p.peekToken.Type == t } +func (p *Parser) curTokenIs(t token.Type) bool { return p.curToken.Type == t } +func (p *Parser) peekTokenIs(t token.Type) bool { return p.peekToken.Type == t } +func (p *Parser) peekPeekTokenIs(t token.Type) bool { return p.peekToken2.Type == t } func (p *Parser) expectPeek(t token.Type) bool { if p.peekTokenIs(t) { @@ -655,6 +660,9 @@ func (p *Parser) parseIdentifier() *ast.Identifier { func (p *Parser) parseQualifiedName() *ast.Identifier { parts := []string{p.curToken.Literal} for p.peekTokenIs(token.DOT) { + if p.peekPeekTokenIs(token.STAR) { + break + } p.nextToken() if !p.expectPeek(token.IDENT) { return &ast.Identifier{Parts: parts}