Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ import (
"github.com/authzed/zed/internal/console"
)

type termChecker interface {
IsTerminal(fd int) bool
}

type realTermChecker struct{}

func (rtc *realTermChecker) IsTerminal(fd int) bool {
return term.IsTerminal(fd)
}

func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
schemaCmd.AddCommand(schemaCopyCmd)
schemaCopyCmd.Flags().Bool("json", false, "output as JSON")
Expand All @@ -50,7 +60,19 @@ var schemaWriteCmd = &cobra.Command{
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
Short: "Write a schema file (.zed or stdin) to the current permissions system",
ValidArgsFunction: commands.FileExtensionCompletions("zed"),
RunE: schemaWriteCmdFunc,
Example: `
Write from a file:
zed schema write schema.zed
Write from stdin:
cat schema.zed | zed schema write
`,
RunE: func(cmd *cobra.Command, args []string) error {
client, err := client.NewClient(cmd)
if err != nil {
return err
}
return schemaWriteCmdImpl(cmd, args, client, &realTermChecker{})
},
}

var schemaCopyCmd = &cobra.Command{
Expand Down Expand Up @@ -79,7 +101,9 @@ var schemaCompileCmd = &cobra.Command{
zed preview schema compile root.zed --out compiled.zed
`,
ValidArgsFunction: commands.FileExtensionCompletions("zed"),
RunE: schemaCompileCmdFunc,
RunE: func(cmd *cobra.Command, args []string) error {
return schemaCompileCmdFunc(cmd, args, &realTermChecker{})
},
}

func schemaDiffCmdFunc(_ *cobra.Command, args []string) error {
Expand Down Expand Up @@ -196,19 +220,16 @@ func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
return nil
}

func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {
intFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServiceClient, terminalChecker termChecker) error {
stdInFd, err := safecast.ToInt(uint(os.Stdin.Fd()))
if err != nil {
return err
}
if len(args) == 0 && term.IsTerminal(intFd) {
return fmt.Errorf("must provide file path or contents via stdin")
}

client, err := client.NewClient(cmd)
if err != nil {
return err
if len(args) == 0 && terminalChecker.IsTerminal(stdInFd) {
return errors.New("must provide file path or contents via stdin")
}

var schemaBytes []byte
switch len(args) {
case 1:
Expand Down Expand Up @@ -246,18 +267,17 @@ func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {

resp, err := client.WriteSchema(cmd.Context(), request)
if err != nil {
log.Fatal().Err(err).Msg("failed to write schema")
return fmt.Errorf("failed to write schema: %w", err)
}
log.Trace().Interface("response", resp).Msg("wrote schema")

if cobrautil.MustGetBool(cmd, "json") {
prettyProto, err := commands.PrettyProto(resp)
if err != nil {
log.Fatal().Err(err).Msg("failed to convert schema to JSON")
return fmt.Errorf("failed to convert schema to JSON: %w", err)
}

console.Println(string(prettyProto))
return nil
}

return nil
Expand Down Expand Up @@ -287,7 +307,7 @@ func rewriteSchema(existingSchemaText string, definitionPrefix string) (string,
// If specifiedPrefix is non-empty, it is returned immediately.
// If existingSchema is non-nil, it is parsed for the prefix.
// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there.
func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) {
func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client v1.SchemaServiceClient, existingSchema *string) (string, error) {
if specifiedPrefix != "" {
return specifiedPrefix, nil
}
Expand Down Expand Up @@ -340,14 +360,14 @@ func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, clien

// Compiles an input schema written in the new composable schema syntax
// and produces it as a fully-realized schema
func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error {
func schemaCompileCmdFunc(cmd *cobra.Command, args []string, termChecker termChecker) error {
stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
if err != nil {
return err
}
outputFilepath := cobrautil.MustGetString(cmd, "out")
if outputFilepath == "" && !term.IsTerminal(stdOutFd) {
return fmt.Errorf("must provide stdout or output file path")
if outputFilepath == "" && !termChecker.IsTerminal(stdOutFd) {
return errors.New("must provide stdout or output file path")
}

inputFilepath := args[0]
Expand Down
174 changes: 173 additions & 1 deletion internal/cmd/schema_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package cmd

import (
"context"
"errors"
"io/fs"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/grpc"

v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/spicedb/pkg/composableschemadsl/compiler"

zedtesting "github.com/authzed/zed/internal/testing"
Expand Down Expand Up @@ -166,16 +170,184 @@ definition resource {
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "out", FlagValue: tempOutFile})

err := schemaCompileCmdFunc(cmd, tc.files)
mockTermCheckerr := &mockTermChecker{returnVal: false}
err := schemaCompileCmdFunc(cmd, tc.files, mockTermCheckerr)
if tc.expectErr == nil {
require.NoError(err)
tempOutString, err := os.ReadFile(tempOutFile)
require.NoError(err)
require.Equal(tc.expectStr, string(tempOutString))
// TODO re-enable after adding a test that uses stdout
// require.Equal(int(os.Stdout.Fd()), mockTermCheckerr.capturedFd, "expected stdout to be checked for terminal")
} else {
require.Error(err)
require.ErrorAs(err, &tc.expectErr)
}
})
}
}

func TestSchemaWrite(t *testing.T) {
t.Parallel()

// Save original stdin
oldStdin := os.Stdin
t.Cleanup(func() {
os.Stdin = oldStdin
})

testCases := map[string]struct {
schemaMakerFn func() ([]string, error)
terminalChecker *mockTermChecker
expectErr string
expectSchemaWritten string
}{
`schema_from_file`: {
schemaMakerFn: func() ([]string, error) {
return []string{
filepath.Join("write-schema-test", "basic.zed"),
}, nil
},
expectSchemaWritten: `definition user {}
definition resource {
relation view: user
permission viewer = view
}`,
terminalChecker: &mockTermChecker{returnVal: false},
},
`schema_from_stdin`: {
schemaMakerFn: func() ([]string, error) {
schemaContent := "definition user{}\ndefinition document { relation read: user }"
pipeRead, pipeWrite, err := os.Pipe()
require.NoError(t, err)
os.Stdin = pipeRead
_, err = pipeWrite.WriteString(schemaContent)
require.NoError(t, err)
err = pipeWrite.Close()
require.NoError(t, err)
return []string{}, nil
},
terminalChecker: &mockTermChecker{returnVal: false},
expectSchemaWritten: "definition user{}\ndefinition document { relation read: user }",
},
`schema_from_stdin_but_terminal`: {
schemaMakerFn: func() ([]string, error) {
schemaContent := "definition user{}\ndefinition document { relation read: user }"
pipeRead, pipeWrite, err := os.Pipe()
require.NoError(t, err)
os.Stdin = pipeRead
_, err = pipeWrite.WriteString(schemaContent)
require.NoError(t, err)
err = pipeWrite.Close()
require.NoError(t, err)
return []string{}, nil
},
terminalChecker: &mockTermChecker{returnVal: true},
expectErr: "must provide file path or contents via stdin",
},
`empty_schema_errors`: {
schemaMakerFn: func() ([]string, error) {
pipeRead, pipeWrite, err := os.Pipe()
require.NoError(t, err)
os.Stdin = pipeRead
_, err = pipeWrite.WriteString("")
require.NoError(t, err)
err = pipeWrite.Close()
require.NoError(t, err)
return []string{}, nil
},
terminalChecker: &mockTermChecker{returnVal: false},
expectErr: "attempted to write empty schema",
},
`write_failure_errors`: {
schemaMakerFn: func() ([]string, error) {
return []string{
filepath.Join("write-schema-test", "basic.zed"),
}, errors.New("write error")
},
terminalChecker: &mockTermChecker{returnVal: false},
expectErr: "error writing schema",
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t,
zedtesting.StringFlag{FlagName: "schema-definition-prefix", FlagValue: ""},
zedtesting.BoolFlag{FlagName: "json", FlagValue: true},
)

args, writeErr := tc.schemaMakerFn()
mockWriteSchemaClientt := &mockWriteSchemaClient{}
if writeErr != nil {
mockWriteSchemaClientt.writeReturnsError = true
}

err := schemaWriteCmdImpl(cmd, args, mockWriteSchemaClientt, tc.terminalChecker)

if tc.expectErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectErr)
return
}

require.NoError(t, err)
require.Equal(t, tc.expectSchemaWritten, mockWriteSchemaClientt.receivedSchema)
if tc.terminalChecker.captured {
require.Equal(t, int(os.Stdin.Fd()), tc.terminalChecker.capturedFd, "expected stdin to be checked for terminal")
}
})
}
}

type mockWriteSchemaClient struct {
existingSchema string
receivedSchema string
writeReturnsError bool
}

var _ v1.SchemaServiceClient = (*mockWriteSchemaClient)(nil)

func (m *mockWriteSchemaClient) WriteSchema(_ context.Context, in *v1.WriteSchemaRequest, _ ...grpc.CallOption) (*v1.WriteSchemaResponse, error) {
if m.writeReturnsError {
return nil, errors.New("error writing schema")
}
m.receivedSchema = in.Schema
return &v1.WriteSchemaResponse{}, nil
}

func (m *mockWriteSchemaClient) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest, _ ...grpc.CallOption) (*v1.ReadSchemaResponse, error) {
return &v1.ReadSchemaResponse{
SchemaText: m.existingSchema,
}, nil
}

func (m *mockWriteSchemaClient) ReflectSchema(_ context.Context, _ *v1.ReflectSchemaRequest, _ ...grpc.CallOption) (*v1.ReflectSchemaResponse, error) {
panic("not implemented")
}

func (m *mockWriteSchemaClient) ComputablePermissions(_ context.Context, _ *v1.ComputablePermissionsRequest, _ ...grpc.CallOption) (*v1.ComputablePermissionsResponse, error) {
panic("not implemented")
}

func (m *mockWriteSchemaClient) DependentRelations(_ context.Context, _ *v1.DependentRelationsRequest, _ ...grpc.CallOption) (*v1.DependentRelationsResponse, error) {
panic("not implemented")
}

func (m *mockWriteSchemaClient) DiffSchema(_ context.Context, _ *v1.DiffSchemaRequest, _ ...grpc.CallOption) (*v1.DiffSchemaResponse, error) {
panic("not implemented")
}

type mockTermChecker struct {
returnVal bool
captured bool
capturedFd int
}

var _ termChecker = (*mockTermChecker)(nil)

func (m *mockTermChecker) IsTerminal(fd int) bool {
m.captured = true
m.capturedFd = fd
return m.returnVal
}
5 changes: 5 additions & 0 deletions internal/cmd/write-schema-test/basic.zed
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
definition user {}
definition resource {
relation view: user
permission viewer = view
}
2 changes: 1 addition & 1 deletion internal/commands/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func schemaReadCmdFunc(cmd *cobra.Command, _ []string) error {
}

// ReadSchema calls read schema for the client and returns the schema found.
func ReadSchema(ctx context.Context, client client.Client) (string, error) {
func ReadSchema(ctx context.Context, client v1.SchemaServiceClient) (string, error) {
request := &v1.ReadSchemaRequest{}
log.Trace().Interface("request", request).Msg("requesting schema read")

Expand Down
Loading