diff --git a/internal/cmd/schema.go b/internal/cmd/schema.go index 4a14ec0..27dbf33 100644 --- a/internal/cmd/schema.go +++ b/internal/cmd/schema.go @@ -30,6 +30,16 @@ import ( "github.com/authzed/zed/internal/console" ) +type termChecker interface { + IsTerminal(fd int) bool +} + +type realTermChecker struct{} + +func (rtc *realTermChecker) IsTerminal(fd int) bool { + return term.IsTerminal(fd) +} + func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { schemaCmd.AddCommand(schemaCopyCmd) schemaCopyCmd.Flags().Bool("json", false, "output as JSON") @@ -50,7 +60,19 @@ var schemaWriteCmd = &cobra.Command{ Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), Short: "Write a schema file (.zed or stdin) to the current permissions system", ValidArgsFunction: commands.FileExtensionCompletions("zed"), - RunE: schemaWriteCmdFunc, + Example: ` + Write from a file: + zed schema write schema.zed + Write from stdin: + cat schema.zed | zed schema write +`, + RunE: func(cmd *cobra.Command, args []string) error { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + return schemaWriteCmdImpl(cmd, args, client, &realTermChecker{}) + }, } var schemaCopyCmd = &cobra.Command{ @@ -79,7 +101,9 @@ var schemaCompileCmd = &cobra.Command{ zed preview schema compile root.zed --out compiled.zed `, ValidArgsFunction: commands.FileExtensionCompletions("zed"), - RunE: schemaCompileCmdFunc, + RunE: func(cmd *cobra.Command, args []string) error { + return schemaCompileCmdFunc(cmd, args, &realTermChecker{}) + }, } func schemaDiffCmdFunc(_ *cobra.Command, args []string) error { @@ -196,19 +220,16 @@ func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { return nil } -func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error { - intFd, err := safecast.ToInt(uint(os.Stdout.Fd())) +func schemaWriteCmdImpl(cmd *cobra.Command, args []string, client v1.SchemaServiceClient, terminalChecker termChecker) error { + stdInFd, err := safecast.ToInt(uint(os.Stdin.Fd())) if err != nil { return err } - if len(args) == 0 && term.IsTerminal(intFd) { - return fmt.Errorf("must provide file path or contents via stdin") - } - client, err := client.NewClient(cmd) - if err != nil { - return err + if len(args) == 0 && terminalChecker.IsTerminal(stdInFd) { + return errors.New("must provide file path or contents via stdin") } + var schemaBytes []byte switch len(args) { case 1: @@ -246,18 +267,17 @@ func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error { resp, err := client.WriteSchema(cmd.Context(), request) if err != nil { - log.Fatal().Err(err).Msg("failed to write schema") + return fmt.Errorf("failed to write schema: %w", err) } log.Trace().Interface("response", resp).Msg("wrote schema") if cobrautil.MustGetBool(cmd, "json") { prettyProto, err := commands.PrettyProto(resp) if err != nil { - log.Fatal().Err(err).Msg("failed to convert schema to JSON") + return fmt.Errorf("failed to convert schema to JSON: %w", err) } console.Println(string(prettyProto)) - return nil } return nil @@ -287,7 +307,7 @@ func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, // If specifiedPrefix is non-empty, it is returned immediately. // If existingSchema is non-nil, it is parsed for the prefix. // Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there. -func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) { +func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client v1.SchemaServiceClient, existingSchema *string) (string, error) { if specifiedPrefix != "" { return specifiedPrefix, nil } @@ -340,14 +360,14 @@ func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, clien // Compiles an input schema written in the new composable schema syntax // and produces it as a fully-realized schema -func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error { +func schemaCompileCmdFunc(cmd *cobra.Command, args []string, termChecker termChecker) error { stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd())) if err != nil { return err } outputFilepath := cobrautil.MustGetString(cmd, "out") - if outputFilepath == "" && !term.IsTerminal(stdOutFd) { - return fmt.Errorf("must provide stdout or output file path") + if outputFilepath == "" && !termChecker.IsTerminal(stdOutFd) { + return errors.New("must provide stdout or output file path") } inputFilepath := args[0] diff --git a/internal/cmd/schema_test.go b/internal/cmd/schema_test.go index 0ae8e6d..3bb37eb 100644 --- a/internal/cmd/schema_test.go +++ b/internal/cmd/schema_test.go @@ -1,13 +1,17 @@ package cmd import ( + "context" + "errors" "io/fs" "os" "path/filepath" "testing" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" zedtesting "github.com/authzed/zed/internal/testing" @@ -166,12 +170,15 @@ definition resource { cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, zedtesting.StringFlag{FlagName: "out", FlagValue: tempOutFile}) - err := schemaCompileCmdFunc(cmd, tc.files) + mockTermCheckerr := &mockTermChecker{returnVal: false} + err := schemaCompileCmdFunc(cmd, tc.files, mockTermCheckerr) if tc.expectErr == nil { require.NoError(err) tempOutString, err := os.ReadFile(tempOutFile) require.NoError(err) require.Equal(tc.expectStr, string(tempOutString)) + // TODO re-enable after adding a test that uses stdout + // require.Equal(int(os.Stdout.Fd()), mockTermCheckerr.capturedFd, "expected stdout to be checked for terminal") } else { require.Error(err) require.ErrorAs(err, &tc.expectErr) @@ -179,3 +186,168 @@ definition resource { }) } } + +func TestSchemaWrite(t *testing.T) { + t.Parallel() + + // Save original stdin + oldStdin := os.Stdin + t.Cleanup(func() { + os.Stdin = oldStdin + }) + + testCases := map[string]struct { + schemaMakerFn func() ([]string, error) + terminalChecker *mockTermChecker + expectErr string + expectSchemaWritten string + }{ + `schema_from_file`: { + schemaMakerFn: func() ([]string, error) { + return []string{ + filepath.Join("write-schema-test", "basic.zed"), + }, nil + }, + expectSchemaWritten: `definition user {} +definition resource { + relation view: user + permission viewer = view +}`, + terminalChecker: &mockTermChecker{returnVal: false}, + }, + `schema_from_stdin`: { + schemaMakerFn: func() ([]string, error) { + schemaContent := "definition user{}\ndefinition document { relation read: user }" + pipeRead, pipeWrite, err := os.Pipe() + require.NoError(t, err) + os.Stdin = pipeRead + _, err = pipeWrite.WriteString(schemaContent) + require.NoError(t, err) + err = pipeWrite.Close() + require.NoError(t, err) + return []string{}, nil + }, + terminalChecker: &mockTermChecker{returnVal: false}, + expectSchemaWritten: "definition user{}\ndefinition document { relation read: user }", + }, + `schema_from_stdin_but_terminal`: { + schemaMakerFn: func() ([]string, error) { + schemaContent := "definition user{}\ndefinition document { relation read: user }" + pipeRead, pipeWrite, err := os.Pipe() + require.NoError(t, err) + os.Stdin = pipeRead + _, err = pipeWrite.WriteString(schemaContent) + require.NoError(t, err) + err = pipeWrite.Close() + require.NoError(t, err) + return []string{}, nil + }, + terminalChecker: &mockTermChecker{returnVal: true}, + expectErr: "must provide file path or contents via stdin", + }, + `empty_schema_errors`: { + schemaMakerFn: func() ([]string, error) { + pipeRead, pipeWrite, err := os.Pipe() + require.NoError(t, err) + os.Stdin = pipeRead + _, err = pipeWrite.WriteString("") + require.NoError(t, err) + err = pipeWrite.Close() + require.NoError(t, err) + return []string{}, nil + }, + terminalChecker: &mockTermChecker{returnVal: false}, + expectErr: "attempted to write empty schema", + }, + `write_failure_errors`: { + schemaMakerFn: func() ([]string, error) { + return []string{ + filepath.Join("write-schema-test", "basic.zed"), + }, errors.New("write error") + }, + terminalChecker: &mockTermChecker{returnVal: false}, + expectErr: "error writing schema", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, + zedtesting.StringFlag{FlagName: "schema-definition-prefix", FlagValue: ""}, + zedtesting.BoolFlag{FlagName: "json", FlagValue: true}, + ) + + args, writeErr := tc.schemaMakerFn() + mockWriteSchemaClientt := &mockWriteSchemaClient{} + if writeErr != nil { + mockWriteSchemaClientt.writeReturnsError = true + } + + err := schemaWriteCmdImpl(cmd, args, mockWriteSchemaClientt, tc.terminalChecker) + + if tc.expectErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.expectErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expectSchemaWritten, mockWriteSchemaClientt.receivedSchema) + if tc.terminalChecker.captured { + require.Equal(t, int(os.Stdin.Fd()), tc.terminalChecker.capturedFd, "expected stdin to be checked for terminal") + } + }) + } +} + +type mockWriteSchemaClient struct { + existingSchema string + receivedSchema string + writeReturnsError bool +} + +var _ v1.SchemaServiceClient = (*mockWriteSchemaClient)(nil) + +func (m *mockWriteSchemaClient) WriteSchema(_ context.Context, in *v1.WriteSchemaRequest, _ ...grpc.CallOption) (*v1.WriteSchemaResponse, error) { + if m.writeReturnsError { + return nil, errors.New("error writing schema") + } + m.receivedSchema = in.Schema + return &v1.WriteSchemaResponse{}, nil +} + +func (m *mockWriteSchemaClient) ReadSchema(_ context.Context, _ *v1.ReadSchemaRequest, _ ...grpc.CallOption) (*v1.ReadSchemaResponse, error) { + return &v1.ReadSchemaResponse{ + SchemaText: m.existingSchema, + }, nil +} + +func (m *mockWriteSchemaClient) ReflectSchema(_ context.Context, _ *v1.ReflectSchemaRequest, _ ...grpc.CallOption) (*v1.ReflectSchemaResponse, error) { + panic("not implemented") +} + +func (m *mockWriteSchemaClient) ComputablePermissions(_ context.Context, _ *v1.ComputablePermissionsRequest, _ ...grpc.CallOption) (*v1.ComputablePermissionsResponse, error) { + panic("not implemented") +} + +func (m *mockWriteSchemaClient) DependentRelations(_ context.Context, _ *v1.DependentRelationsRequest, _ ...grpc.CallOption) (*v1.DependentRelationsResponse, error) { + panic("not implemented") +} + +func (m *mockWriteSchemaClient) DiffSchema(_ context.Context, _ *v1.DiffSchemaRequest, _ ...grpc.CallOption) (*v1.DiffSchemaResponse, error) { + panic("not implemented") +} + +type mockTermChecker struct { + returnVal bool + captured bool + capturedFd int +} + +var _ termChecker = (*mockTermChecker)(nil) + +func (m *mockTermChecker) IsTerminal(fd int) bool { + m.captured = true + m.capturedFd = fd + return m.returnVal +} diff --git a/internal/cmd/write-schema-test/basic.zed b/internal/cmd/write-schema-test/basic.zed new file mode 100644 index 0000000..522e90c --- /dev/null +++ b/internal/cmd/write-schema-test/basic.zed @@ -0,0 +1,5 @@ +definition user {} +definition resource { + relation view: user + permission viewer = view +} \ No newline at end of file diff --git a/internal/commands/schema.go b/internal/commands/schema.go index b56f152..cb02074 100644 --- a/internal/commands/schema.go +++ b/internal/commands/schema.go @@ -68,7 +68,7 @@ func schemaReadCmdFunc(cmd *cobra.Command, _ []string) error { } // ReadSchema calls read schema for the client and returns the schema found. -func ReadSchema(ctx context.Context, client client.Client) (string, error) { +func ReadSchema(ctx context.Context, client v1.SchemaServiceClient) (string, error) { request := &v1.ReadSchemaRequest{} log.Trace().Interface("request", request).Msg("requesting schema read")