diff --git a/.version b/.version index f94611877..1852307b0 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.2.610 \ No newline at end of file +0.2.611 \ No newline at end of file diff --git a/parser/v2/goexpression/fuzz.sh b/parser/v2/goexpression/fuzz.sh index 0119459a6..20062f9ed 100755 --- a/parser/v2/goexpression/fuzz.sh +++ b/parser/v2/goexpression/fuzz.sh @@ -8,6 +8,8 @@ echo Case go test -fuzz=FuzzCaseStandard -fuzztime=120s echo Default go test -fuzz=FuzzCaseDefault -fuzztime=120s +echo TemplExpression +go test -fuzz=FuzzTemplExpression -fuzztime=120s echo Expression go test -fuzz=FuzzExpression -fuzztime=120s echo SliceArgs diff --git a/parser/v2/goexpression/parse.go b/parser/v2/goexpression/parse.go index 86f686531..77c413cc9 100644 --- a/parser/v2/goexpression/parse.go +++ b/parser/v2/goexpression/parse.go @@ -2,11 +2,12 @@ package goexpression import ( "errors" + "fmt" "go/ast" "go/parser" + "go/scanner" "go/token" "regexp" - "slices" "strings" "unicode" ) @@ -104,56 +105,91 @@ func Switch(content string) (start, end int, err error) { }) } -var boundaryTokens = map[rune]any{ - ')': nil, - '}': nil, - ']': nil, - ' ': nil, - '\t': nil, - '\n': nil, - '@': nil, - '<': nil, - '+': nil, - '.': nil, -} +func TemplExpression(src string) (start, end int, err error) { + var s scanner.Scanner + fset := token.NewFileSet() + file := fset.AddFile("", fset.Base(), len(src)) + errorHandler := func(pos token.Position, msg string) { + err = fmt.Errorf("error parsing expression: %v", msg) + } + s.Init(file, []byte(src), errorHandler, 0) -func allBoundaries(content string) (boundaries []int) { - for i, r := range content { - if _, ok := boundaryTokens[r]; ok { - boundaries = append(boundaries, i) - boundaries = append(boundaries, i+1) + // Read chains of identifiers, e.g.: + // components.Variable + // components[0].Variable + // components["name"].Function() + // functionCall(withLots(), func() { return true }) + ep := NewExpressionParser() + for { + pos, tok, lit := s.Scan() + stop, err := ep.Insert(pos, tok, lit) + if err != nil { + return 0, 0, err + } + if stop { + break } } - boundaries = append(boundaries, len(content)) - return + return 0, ep.End, nil } -func Expression(content string) (start, end int, err error) { - var candidates []int - for _, to := range allBoundaries(content) { - expr, err := parser.ParseExpr(content[:to]) - if err != nil { - continue - } - switch expr := expr.(type) { - case *ast.CallExpr: - end = int(expr.Rparen) - default: - end = int(expr.End()) - 1 +func Expression(src string) (start, end int, err error) { + var s scanner.Scanner + fset := token.NewFileSet() + file := fset.AddFile("", fset.Base(), len(src)) + errorHandler := func(pos token.Position, msg string) { + err = fmt.Errorf("error parsing expression: %v", msg) + } + s.Init(file, []byte(src), errorHandler, 0) + + // Read chains of identifiers and constants up until RBRACE, e.g.: + // true + // 123.45 == true + // components.Variable + // components[0].Variable + // components["name"].Function() + // functionCall(withLots(), func() { return true }) + // !true + parenDepth := 0 + bracketDepth := 0 + braceDepth := 0 +loop: + for { + pos, tok, lit := s.Scan() + if tok == token.EOF { + break loop } - // If the expression ends with `...` then it's a child spread expression. - if end < len(content) { - if strings.HasPrefix(content[end:], "...") { - end += len("...") + switch tok { + case token.LPAREN: // ( + parenDepth++ + case token.RPAREN: // ) + end = int(pos) + parenDepth-- + case token.LBRACK: // [ + bracketDepth++ + case token.RBRACK: // ] + end = int(pos) + bracketDepth-- + case token.LBRACE: // { + braceDepth++ + case token.RBRACE: // } + braceDepth-- + if braceDepth < 0 { + // We've hit the end of the expression. + break loop } + end = int(pos) + case token.IDENT, token.INT, token.FLOAT, token.IMAG, token.CHAR, token.STRING: + end = int(pos) + len(lit) - 1 + case token.SEMICOLON: + continue + case token.ILLEGAL: + return 0, 0, fmt.Errorf("illegal token: %v", lit) + default: + end = int(pos) + len(tok.String()) - 1 } - candidates = append(candidates, end) - } - if len(candidates) == 0 { - return 0, 0, ErrExpectedNodeNotFound } - slices.Sort(candidates) - return 0, candidates[len(candidates)-1], nil + return start, end, nil } func SliceArgs(content string) (expr string, err error) { diff --git a/parser/v2/goexpression/parse_test.go b/parser/v2/goexpression/parse_test.go index 765c96e1d..84aa25913 100644 --- a/parser/v2/goexpression/parse_test.go +++ b/parser/v2/goexpression/parse_test.go @@ -295,6 +295,22 @@ func FuzzCaseDefault(f *testing.F) { } var expressionTests = []testInput{ + { + name: "string literal", + input: `"hello"`, + }, + { + name: "string literal with escape", + input: `"hello\n"`, + }, + { + name: "backtick string literal", + input: "`hello`", + }, + { + name: "backtick string literal containing double quote", + input: "`hello" + `"` + `world` + "`", + }, { name: "function call in package", input: `components.Other()`, @@ -347,8 +363,6 @@ func TestExpression(t *testing.T) { "}", "\t}", " }", - "", - "

/

", } for _, test := range expressionTests { for i, suffix := range suffixes { @@ -357,6 +371,94 @@ func TestExpression(t *testing.T) { } } +var templExpressionTests = []testInput{ + { + name: "function call in package", + input: `components.Other()`, + }, + { + name: "slice index call", + input: `components[0].Other()`, + }, + { + name: "map index function call", + input: `components["name"].Other()`, + }, + { + name: "map index function call backtick literal", + input: "components[`name" + `"` + "`].Other()", + }, + { + name: "function literal", + input: `components["name"].Other(func() bool { return true })`, + }, + { + name: "multiline function call", + input: `component(map[string]string{ + "namea": "name_a", + "nameb": "name_b", + })`, + }, + { + name: "call with braces and brackets", + input: `templates.New(test{}, other())`, + }, + { + name: "struct method call", + input: `typeName{}.Method()`, + }, + { + name: "struct method call in other package", + input: "layout.DefaultLayout{}.Compile()", + }, + { + name: "bare variable", + input: `component`, + }, +} + +func TestTemplExpression(t *testing.T) { + prefix := "" + suffixes := []string{ + "", + "}", + "\t}", + " }", + "", + "

/

", + " just some text", + " {
Child content
}", + } + for _, test := range templExpressionTests { + for i, suffix := range suffixes { + t.Run(fmt.Sprintf("%s_%d", test.name, i), run(test, prefix, suffix, TemplExpression)) + } + } +} + +func FuzzTemplExpression(f *testing.F) { + suffixes := []string{ + "", + " }", + " }}\n}", + "...", + } + for _, test := range expressionTests { + for _, suffix := range suffixes { + f.Add(test.input + suffix) + } + } + f.Fuzz(func(t *testing.T, s string) { + src := "switch " + s + start, end, err := TemplExpression(src) + if err != nil { + t.Skip() + return + } + panicIfInvalid(src, start, end) + }) +} + func FuzzExpression(f *testing.F) { suffixes := []string{ "", diff --git a/parser/v2/goexpression/parsebench_test.go b/parser/v2/goexpression/parsebench_test.go new file mode 100644 index 000000000..8cc03253b --- /dev/null +++ b/parser/v2/goexpression/parsebench_test.go @@ -0,0 +1,105 @@ +package goexpression + +import "testing" + +var testStringExpression = `"this string expression" } +
+ But afterwards, it keeps searching. +
+ +
+ But that's not right, we can stop searching. It won't find anything valid. +
+ +
+ Certainly not later in the file. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+` + +func BenchmarkExpression(b *testing.B) { + // Baseline... + // BenchmarkExpression-10 6484 184862 ns/op + // Updated... + // BenchmarkExpression-10 3942538 279.6 ns/op + for n := 0; n < b.N; n++ { + start, end, err := Expression(testStringExpression) + if err != nil { + b.Fatal(err) + } + if start != 0 || end != 24 { + b.Fatalf("expected 0, 24, got %d, %d", start, end) + } + } +} + +var testTemplExpression = `templates.CallMethod(map[string]any{ + "name": "this string expression", +}) + +
+ But afterwards, it keeps searching. +
+ +
+ But that's not right, we can stop searching. It won't find anything valid. +
+ +
+ Certainly not later in the file. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+ +
+ It's going to try all the tokens. + )}]@<+. +
+` + +func BenchmarkTemplExpression(b *testing.B) { + // BenchmarkTemplExpression-10 2694 431934 ns/op + // Updated... + // BenchmarkTemplExpression-10 1339399 897.6 ns/op + for n := 0; n < b.N; n++ { + start, end, err := TemplExpression(testTemplExpression) + if err != nil { + b.Fatal(err) + } + if start != 0 || end != 74 { + b.Fatalf("expected 0, 74, got %d, %d", start, end) + } + } +} diff --git a/parser/v2/goexpression/scanner.go b/parser/v2/goexpression/scanner.go new file mode 100644 index 000000000..56c044451 --- /dev/null +++ b/parser/v2/goexpression/scanner.go @@ -0,0 +1,149 @@ +package goexpression + +import ( + "fmt" + "go/token" +) + +type Stack[T any] []T + +func (s *Stack[T]) Push(v T) { + *s = append(*s, v) +} + +func (s *Stack[T]) Pop() (v T) { + if len(*s) == 0 { + return v + } + v = (*s)[len(*s)-1] + *s = (*s)[:len(*s)-1] + return v +} + +func (s *Stack[T]) Peek() (v T) { + if len(*s) == 0 { + return v + } + return (*s)[len(*s)-1] +} + +var goTokenOpenToClose = map[token.Token]token.Token{ + token.LPAREN: token.RPAREN, + token.LBRACE: token.RBRACE, + token.LBRACK: token.RBRACK, +} + +var goTokenCloseToOpen = map[token.Token]token.Token{ + token.RPAREN: token.LPAREN, + token.RBRACE: token.LBRACE, + token.RBRACK: token.LBRACK, +} + +type ErrUnbalanced struct { + Token token.Token +} + +func (e ErrUnbalanced) Error() string { + return fmt.Sprintf("unbalanced '%s'", e.Token) +} + +func NewExpressionParser() *ExpressionParser { + return &ExpressionParser{ + Stack: make(Stack[token.Token], 0), + Previous: token.PERIOD, + Fns: make(Stack[int], 0), + } +} + +type ExpressionParser struct { + Stack Stack[token.Token] + End int + Previous token.Token + Fns Stack[int] // Stack of function depths. +} + +func (ep *ExpressionParser) Insert(pos token.Pos, tok token.Token, lit string) (stop bool, err error) { + defer func() { + ep.Previous = tok + }() + if tok == token.FUNC { + // The next open brace will be the body of a function literal, so push the fn depth. + ep.Fns.Push(len(ep.Stack)) + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // Opening a pair can be done after an ident, but it can also be a func literal. + // e.g. "name()", or "name(func() bool { return true })". + if _, ok := goTokenOpenToClose[tok]; ok { + if tok == token.LBRACE { + if ep.Previous != token.IDENT { + return true, nil + } + hasSpace := (int(pos) - 1) > ep.End + if hasSpace && len(ep.Fns) == 0 { + // There's a space, and we're not in a function so stop. + return true, nil + } + } + ep.Stack.Push(tok) + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // Closing a pair. + if expected, ok := goTokenCloseToOpen[tok]; ok { + if len(ep.Stack) == 0 { + // We've got a close token, but there's nothing to close, so we must be done. + return true, nil + } + actual := ep.Stack.Pop() + if !ok { + return false, ErrUnbalanced{tok} + } + if actual != expected { + return false, ErrUnbalanced{tok} + } + // If we're closing a function, pop the function depth. + if tok == token.RBRACE && len(ep.Stack) == ep.Fns.Peek() { + ep.Fns.Pop() + } + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // If we're within a pair, we allow anything. + if len(ep.Stack) > 0 { + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // We allow a period to follow an ident or a closer. + // e.g. "package.name" or "typeName{field: value}.name()". + if tok == token.PERIOD && (ep.Previous == token.IDENT || isCloser(ep.Previous)) { + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // We allow an ident to follow a period or a closer. + // e.g. "package.name", "typeName{field: value}.name()". + // or "call().name", "call().name()". + // But not "package .name" or "typeName{field: value} .name()". + if tok == token.IDENT && (ep.Previous == token.PERIOD || isCloser(ep.Previous)) { + if (int(pos) - 1) > ep.End { + // There's a space, so stop. + return true, nil + } + ep.End = int(pos) + len(tokenString(tok, lit)) - 1 + return false, nil + } + // Anything else returns stop=true. + return true, nil +} + +func tokenString(tok token.Token, lit string) string { + if tok.IsKeyword() || tok.IsOperator() { + return tok.String() + } + return lit +} + +func isCloser(tok token.Token) bool { + _, ok := goTokenCloseToOpen[tok] + return ok +} diff --git a/parser/v2/goexpression/testdata/fuzz/FuzzExpression/ac5d99902f5e7914 b/parser/v2/goexpression/testdata/fuzz/FuzzExpression/ac5d99902f5e7914 new file mode 100644 index 000000000..ebfdb088d --- /dev/null +++ b/parser/v2/goexpression/testdata/fuzz/FuzzExpression/ac5d99902f5e7914 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("#") diff --git a/parser/v2/templelementparser.go b/parser/v2/templelementparser.go index 56239cd0e..9a50b5b03 100644 --- a/parser/v2/templelementparser.go +++ b/parser/v2/templelementparser.go @@ -15,7 +15,7 @@ func (p templElementExpressionParser) Parse(pi *parse.Input) (n Node, ok bool, e var r TemplElementExpression // Parse the Go expression. - if r.Expression, err = parseGo("templ element", pi, goexpression.Expression); err != nil { + if r.Expression, err = parseGo("templ element", pi, goexpression.TemplExpression); err != nil { return r, false, err }