diff --git a/pkg/scanner/ast/tree/tree.go b/pkg/scanner/ast/tree/tree.go index 0a75f831b..f1a0d4996 100644 --- a/pkg/scanner/ast/tree/tree.go +++ b/pkg/scanner/ast/tree/tree.go @@ -250,7 +250,7 @@ func nodeListToID(nodes []*Node) []int { // FIXME: remove this func (node *Node) EachContentPart(onText func(text string) error, onChild func(child *Node) error) error { start := node.ContentStart.Byte - end := start + end := node.ContentEnd.Byte emit := func() error { if end <= start { @@ -274,7 +274,7 @@ func (node *Node) EachContentPart(onText func(text string) error, onChild func(c } start = child.ContentEnd.Byte - end = start + end = node.ContentEnd.Byte } if err := emit(); err != nil { diff --git a/pkg/scanner/ast/tree/tree_test.go b/pkg/scanner/ast/tree/tree_test.go index 8ee555c8e..b94d85064 100644 --- a/pkg/scanner/ast/tree/tree_test.go +++ b/pkg/scanner/ast/tree/tree_test.go @@ -4,16 +4,18 @@ import ( "context" "testing" - "github.com/bradleyjkemp/cupaloy" sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/javascript" "github.com/smacker/go-tree-sitter/ruby" "github.com/bearer/bearer/pkg/scanner/ast/tree" + + "github.com/bradleyjkemp/cupaloy" + "github.com/stretchr/testify/assert" ) -func parseTree(t *testing.T, content string) *tree.Tree { +func parseTree(t *testing.T, sitterLanguage *sitter.Language, content string) *tree.Tree { contentBytes := []byte(content) - sitterLanguage := ruby.GetLanguage() sitterRootNode, err := sitter.ParseCtx(context.Background(), contentBytes, sitterLanguage) if err != nil { @@ -24,7 +26,7 @@ func parseTree(t *testing.T, content string) *tree.Tree { } func TestTree(t *testing.T) { - tree := parseTree(t, ` + tree := parseTree(t, ruby.GetLanguage(), ` def m(a) a.foo end @@ -32,3 +34,30 @@ func TestTree(t *testing.T) { cupaloy.SnapshotT(t, tree.RootNode().Dump()) } + +func TestContentParts(t *testing.T) { + for _, test := range []struct{ expression, expected string }{ + {"`abc`", "abc"}, + {"`a${b}c`", "a*c"}, + {"`${b}c`", "*c"}, + {"`a${b}`", "a*"}, + } { + t.Run(test.expression, func(tt *testing.T) { + ast := parseTree(tt, javascript.GetLanguage(), test.expression) + stringNode := ast.RootNode().NamedChildren()[0].NamedChildren()[0] + assert.Equal(tt, "template_string", stringNode.Type()) + + var result string + err := stringNode.EachContentPart(func(text string) error { + result += text + return nil + }, func(child *tree.Node) error { + result += "*" + return nil + }) + assert.NoError(tt, err) + + assert.Equal(tt, test.expected, result) + }) + } +}