Skip to content

Commit

Permalink
lsp: Add package/rule ref completions
Browse files Browse the repository at this point in the history
Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 committed May 23, 2024
1 parent e21051b commit 69323c6
Show file tree
Hide file tree
Showing 16 changed files with 1,186 additions and 229 deletions.
31 changes: 31 additions & 0 deletions internal/ast/ref.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package ast

import (
"strings"

"github.com/open-policy-agent/opa/ast"
)

func RefToString(ref ast.Ref) string {
sb := strings.Builder{}

for i, part := range ref {
if part.IsGround() {
if i > 0 {
sb.WriteString(".")
}

sb.WriteString(strings.Trim(part.Value.String(), `"`))
} else {
if i == 0 {
sb.WriteString(strings.Trim(part.Value.String(), `"`))
} else {
sb.WriteString("[")
sb.WriteString(strings.Trim(part.Value.String(), `"`))
sb.WriteString("]")
}
}
}

return sb.String()
}
75 changes: 75 additions & 0 deletions internal/ast/ref_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package ast

import (
"testing"

"github.com/open-policy-agent/opa/ast"
)

func TestRefToString(t *testing.T) {
t.Parallel()

cases := []struct {
title string
ref ast.Ref
expected string
}{
{
"single var",
ast.Ref{
ast.VarTerm("foo"),
},
"foo",
},
{
"var in middle",
ast.Ref{
ast.StringTerm("foo"),
ast.VarTerm("bar"),
ast.StringTerm("baz"),
},
"foo[bar].baz",
},
{
"strings",
ast.Ref{
ast.DefaultRootDocument,
ast.StringTerm("foo"),
ast.StringTerm("bar"),
ast.StringTerm("baz"),
},
"data.foo.bar.baz",
},
{
"consecutive vars",
ast.Ref{
ast.VarTerm("foo"),
ast.VarTerm("bar"),
ast.VarTerm("baz"),
},
"foo[bar][baz]",
},
{
"mixed",
ast.Ref{
ast.VarTerm("foo"),
ast.VarTerm("bar"),
ast.StringTerm("baz"),
ast.VarTerm("qux"),
ast.StringTerm("quux"),
},
"foo[bar].baz[qux].quux",
},
}

for _, tc := range cases {
t.Run(tc.title, func(t *testing.T) {
t.Parallel()

result := RefToString(tc.ref)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}
102 changes: 102 additions & 0 deletions internal/ast/rule.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package ast

import (
"fmt"
"strings"

"github.com/open-policy-agent/opa/ast"

"github.com/styrainc/regal/internal/lsp/rego"
)

func GetRuleDetail(rule *ast.Rule) string {
if rule.Head.Args != nil {
return "function" + rule.Head.Args.String()
}

if rule.Head.Key != nil && rule.Head.Value == nil {
return "multi-value rule"
}

if rule.Head.Value == nil {
return ""
}

detail := "single-value "

if rule.Head.Key != nil {
detail += "map "
}

detail += "rule"

switch v := rule.Head.Value.Value.(type) {
case ast.Boolean:
if strings.HasPrefix(rule.Head.Ref()[0].String(), "test_") {
detail += " (test)"
} else {
detail += " (boolean)"
}
case ast.Number:
detail += " (number)"
case ast.String:
detail += " (string)"
case *ast.Array, *ast.ArrayComprehension:
detail += " (array)"
case ast.Object, *ast.ObjectComprehension:
detail += " (object)"
case ast.Set, *ast.SetComprehension:
detail += " (set)"
case ast.Call:
name := v[0].String()

if builtin, ok := rego.BuiltIns[name]; ok {
retType := builtin.Decl.NamedResult().String()

detail += fmt.Sprintf(" (%s)", simplifyType(retType))
}
}

return detail
}

// IsConstant returns true if the rule is a "constant" rule, i.e.
// one without conditions and scalar value in the head.
func IsConstant(rule *ast.Rule) bool {
isScalar := false

if rule.Head.Value == nil {
return false
}

switch rule.Head.Value.Value.(type) {
case ast.Boolean, ast.Number, ast.String, ast.Null:
isScalar = true
}

return isScalar &&
rule.Head.Args == nil &&
rule.Body.Equal(ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))) &&
rule.Else == nil
}

// simplifyType removes anything but the base type from the type name.
func simplifyType(name string) string {
result := name

if strings.Contains(result, ":") {
result = result[strings.Index(result, ":")+1:]
}

// silence gocritic linter here as strings.Index can in
// fact *not* return -1 in these cases
if strings.Contains(result, "[") {
result = result[:strings.Index(result, "[")] //nolint:gocritic
}

if strings.Contains(result, "<") {
result = result[:strings.Index(result, "<")] //nolint:gocritic
}

return strings.TrimSpace(result)
}
93 changes: 93 additions & 0 deletions internal/ast/rule_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package ast

import (
"testing"

"github.com/styrainc/regal/internal/parse"
)

func TestGetRuleDetail(t *testing.T) {
t.Parallel()

cases := []struct {
input string
expected string
}{
{
`allow := true`,
`single-value rule (boolean)`,
},
{
`allow := [1,2,3]`,
`single-value rule (array)`,
},
{
`allow := "foo"`,
`single-value rule (string)`,
},
{
`foo contains 1 if true`,
`multi-value rule`,
},
{
`func(x) := true`,
`function(x)`,
},
}

for _, tc := range cases {
t.Run(tc.input, func(t *testing.T) {
t.Parallel()

mod := parse.MustParseModule("package example\nimport rego.v1\n" + tc.input)

if len(mod.Rules) != 1 {
t.Fatalf("Expected 1 rule, got %d", len(mod.Rules))
}

rule := mod.Rules[0]

result := GetRuleDetail(rule)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}

func TestSimplifyType(t *testing.T) {
t.Parallel()

cases := []struct {
input string
expected string
}{
{
"set",
"set",
},
{
"set[any]",
"set",
},
{
"any<set, object>",
"any",
},
{
"output: any<set[any], object>",
"any",
},
}

for _, tc := range cases {
t.Run(tc.input, func(t *testing.T) {
t.Parallel()

result := simplifyType(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}
28 changes: 28 additions & 0 deletions internal/lsp/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ type Cache struct {
diagnosticsParseErrors map[string][]types.Diagnostic
diagnosticsParseMu sync.Mutex

// builtinPositionsFile is a map of file URI to builtin positions for that file
builtinPositionsFile map[string]map[uint][]types.BuiltinPosition
builtinPositionsMu sync.Mutex

// fileRefs is a map of file URI to completion items that might be suggested from that file
fileRefs map[string]map[string]types.Ref
fileRefMu sync.Mutex
}

func NewCache() *Cache {
Expand All @@ -47,6 +52,8 @@ func NewCache() *Cache {
diagnosticsParseErrors: make(map[string][]types.Diagnostic),

builtinPositionsFile: make(map[string]map[uint][]types.BuiltinPosition),

fileRefs: make(map[string]map[string]types.Ref),
}
}

Expand Down Expand Up @@ -202,6 +209,27 @@ func (c *Cache) GetAllBuiltInPositions() map[string]map[uint][]types.BuiltinPosi
return c.builtinPositionsFile
}

func (c *Cache) SetFileRefs(uri string, items map[string]types.Ref) {
c.fileRefMu.Lock()
defer c.fileRefMu.Unlock()

c.fileRefs[uri] = items
}

func (c *Cache) GetFileRefs(uri string) map[string]types.Ref {
c.fileRefMu.Lock()
defer c.fileRefMu.Unlock()

return c.fileRefs[uri]
}

func (c *Cache) GetAllFileRefs() map[string]map[string]types.Ref {
c.fileRefMu.Lock()
defer c.fileRefMu.Unlock()

return c.fileRefs
}

// Delete removes all cached data for a given URI.
func (c *Cache) Delete(uri string) {
c.fileContentsMu.Lock()
Expand Down
2 changes: 2 additions & 0 deletions internal/lsp/completions/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func NewDefaultManager(c *cache.Cache) *Manager {
m.RegisterProvider(&providers.PackageName{})
m.RegisterProvider(&providers.BuiltIns{})
m.RegisterProvider(&providers.RegoV1{})
m.RegisterProvider(&providers.PackageRefs{})
m.RegisterProvider(&providers.RuleFromImportedPackageRefs{})

return m
}
Expand Down
Loading

0 comments on commit 69323c6

Please sign in to comment.