From b63f5e144830c4ac8fc8863e97e79f564e1ac151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Wed, 23 Apr 2025 15:06:02 +0100 Subject: [PATCH] print usage only on validation errors the usage was being printed on any error, which was rather unpleasant to look at because it makes the output difficult to parse for clues on what happened. The whole goal of this commit is to improve UX. Cobra supports "SilenceErrors" and "SilenceUsage", and a callback function "SetFlagErrorFunc" for flag parsing errors. Unfortunately, only a subset of errors fall under "flag parsing error" category. To amend this, a special error Wrapper is introduced that signals an error is a validation error, so the command can manually print the usage when "SilenceUsage" is enabled, which this commit does. A few tests are added that assert the output. Due to the pervasive use of globals, some refactoring was conducted in cmd.go to make it more test friendly. --- internal/cmd/backup.go | 16 ++--- internal/cmd/cmd.go | 60 ++++++++++++---- internal/cmd/cmd_test.go | 114 ++++++++++++++++++++++++++++++ internal/cmd/context.go | 9 +-- internal/cmd/import.go | 3 +- internal/cmd/preview.go | 2 +- internal/cmd/schema.go | 6 +- internal/cmd/validate.go | 2 +- internal/commands/permission.go | 12 ++-- internal/commands/relationship.go | 10 +-- internal/commands/schema.go | 2 +- internal/commands/util.go | 25 +++++++ internal/commands/util_test.go | 44 ++++++++++++ internal/commands/watch.go | 4 +- 14 files changed, 265 insertions(+), 44 deletions(-) create mode 100644 internal/cmd/cmd_test.go create mode 100644 internal/commands/util_test.go diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index e66414a4..57fa83c8 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -53,7 +53,7 @@ var ( backupCmd = &cobra.Command{ Use: "backup ", Short: "Create, restore, and inspect permissions system backups", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), // Create used to be on the root, so add it here for back-compat. RunE: withErrorHandling(backupCreateCmdFunc), } @@ -61,21 +61,21 @@ var ( backupCreateCmd = &cobra.Command{ Use: "create ", Short: "Backup a permission system to a file", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), RunE: withErrorHandling(backupCreateCmdFunc), } backupRestoreCmd = &cobra.Command{ Use: "restore ", Short: "Restore a permission system from a file", - Args: commands.StdinOrExactArgs(1), + Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)), RunE: backupRestoreCmdFunc, } backupParseSchemaCmd = &cobra.Command{ Use: "parse-schema ", Short: "Extract the schema from a backup file", - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), RunE: func(cmd *cobra.Command, args []string) error { return backupParseSchemaCmdFunc(cmd, os.Stdout, args) }, @@ -84,7 +84,7 @@ var ( backupParseRevisionCmd = &cobra.Command{ Use: "parse-revision ", Short: "Extract the revision from a backup file", - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), RunE: func(cmd *cobra.Command, args []string) error { return backupParseRevisionCmdFunc(cmd, os.Stdout, args) }, @@ -93,7 +93,7 @@ var ( backupParseRelsCmd = &cobra.Command{ Use: "parse-relationships ", Short: "Extract the relationships from a backup file", - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), RunE: func(cmd *cobra.Command, args []string) error { return backupParseRelsCmdFunc(cmd, os.Stdout, args) }, @@ -102,7 +102,7 @@ var ( backupRedactCmd = &cobra.Command{ Use: "redact ", Short: "Redact a backup file to remove sensitive information", - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), RunE: func(cmd *cobra.Command, args []string) error { return backupRedactCmdFunc(cmd, args) }, @@ -129,7 +129,7 @@ func registerBackupCmd(rootCmd *cobra.Command) { restoreCmd := &cobra.Command{ Use: "restore ", Short: "Restore a permission system from a backup file", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), RunE: backupRestoreCmdFunc, Hidden: true, } diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 99df3045..db15c1dd 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -6,6 +6,7 @@ import ( "io" "os" "os/signal" + "strings" "syscall" "github.com/jzelinskie/cobrautil/v2" @@ -42,7 +43,15 @@ func init() { log.Logger = l } -// This function is utilised to generate docs for zed +var flagError = flagErrorFunc + +func flagErrorFunc(cmd *cobra.Command, err error) error { + cmd.Println(err) + cmd.Println(cmd.UsageString()) + return errParsing +} + +// InitialiseRootCmd This function is utilised to generate docs for zed func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { rootCmd := &cobra.Command{ Use: "zed", @@ -54,12 +63,10 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { commands.InjectRequestID, ), SilenceErrors: true, - SilenceUsage: false, + SilenceUsage: true, } - rootCmd.SetFlagErrorFunc(func(cmd *cobra.Command, err error) error { - cmd.Println(err) - cmd.Println(cmd.UsageString()) - return errParsing + rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error { + return flagError(command, err) }) zl.RegisterFlags(rootCmd.PersistentFlags()) @@ -92,7 +99,7 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { rootCmd.AddCommand(&cobra.Command{ Use: "use ", Short: "Alias for `zed context use`", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), RunE: contextUseCmdFunc, ValidArgsFunction: ContextGet, }) @@ -119,6 +126,12 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { } func Run() { + if err := runWithoutExit(); err != nil { + os.Exit(1) + } +} + +func runWithoutExit() error { zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel)) rootCmd := InitialiseRootCmd(zl) @@ -141,11 +154,34 @@ func Run() { } }() - if err := rootCmd.ExecuteContext(ctx); err != nil { - if !errors.Is(err, errParsing) { - log.Err(err).Msg("terminated with errors") - } + return handleError(rootCmd, rootCmd.ExecuteContext(ctx)) +} - os.Exit(1) +func handleError(command *cobra.Command, err error) error { + if err == nil { + return nil + } + // this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately + // parsed. This is necessary to be able to print the proper command-specific usage + var findErr error + var cmdToExecute *cobra.Command + args := os.Args[1:] + if command.TraverseChildren { + cmdToExecute, _, findErr = command.Traverse(args) + } else { + cmdToExecute, _, findErr = command.Find(args) } + if findErr != nil { + cmdToExecute = command + } + + if errors.Is(err, commands.ValidationError{}) { + _ = flagError(cmdToExecute, err) + } else if err != nil && strings.Contains(err.Error(), "unknown command") { + _ = flagError(cmdToExecute, err) + } else if !errors.Is(err, errParsing) { + log.Err(err).Msg("terminated with errors") + } + + return err } diff --git a/internal/cmd/cmd_test.go b/internal/cmd/cmd_test.go new file mode 100644 index 00000000..d86a4d09 --- /dev/null +++ b/internal/cmd/cmd_test.go @@ -0,0 +1,114 @@ +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/google/uuid" + "github.com/jzelinskie/cobrautil/v2/cobrazerolog" + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +// note: these tests mess with global variables, so do not run in parallel with other tests. +func TestCommandOutput(t *testing.T) { + cases := []struct { + name string + flagErrorContains string + expectUsageContains string + expectFlagErrorCalled bool + expectStdErrorMsg string + command []string + }{ + { + name: "prints usage on invalid command error", + command: []string{"zed", "madeupcommand"}, + expectFlagErrorCalled: true, + flagErrorContains: "unknown command", + expectUsageContains: "zed [command]", + }, + { + name: "prints usage on invalid flag error", + command: []string{"zed", "version", "--madeupflag"}, + expectFlagErrorCalled: true, + flagErrorContains: "unknown flag: --madeupflag", + expectUsageContains: "zed version [flags]", + }, + { + name: "prints usage on parameter validation error", + command: []string{"zed", "validate"}, + expectFlagErrorCalled: true, + flagErrorContains: "requires at least 1 arg(s), only received 0", + expectUsageContains: "zed validate [flags]", + }, + { + name: "prints correct usage", + command: []string{"zed", "perm", "check"}, + expectFlagErrorCalled: true, + flagErrorContains: "accepts 3 arg(s), received 0", + expectUsageContains: "zed permission check ", + }, + { + name: "does not print usage on command error", + command: []string{"zed", "validate", uuid.NewString()}, + expectFlagErrorCalled: false, + expectStdErrorMsg: "terminated with errors", + }, + } + + zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel)) + + rootCmd := InitialiseRootCmd(zl) + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + var flagErrorCalled bool + testFlagError := func(cmd *cobra.Command, err error) error { + require.ErrorContains(t, err, tt.flagErrorContains) + require.Contains(t, cmd.UsageString(), tt.expectUsageContains) + flagErrorCalled = true + return errParsing + } + stderrFile := setupOutputForTest(t, testFlagError, tt.command...) + + err := handleError(rootCmd, rootCmd.ExecuteContext(t.Context())) + require.Error(t, err) + stdErrBytes, err := os.ReadFile(stderrFile) + require.NoError(t, err) + if tt.expectStdErrorMsg != "" { + require.Contains(t, string(stdErrBytes), tt.expectStdErrorMsg) + } else { + require.Len(t, stdErrBytes, 0) + } + require.Equal(t, tt.expectFlagErrorCalled, flagErrorCalled) + }) + } +} + +func setupOutputForTest(t *testing.T, testFlagError func(cmd *cobra.Command, err error) error, args ...string) string { + t.Helper() + + originalLevel := zerolog.GlobalLevel() + originalFlagError := flagError + originalArgs := os.Args + originalStderr := os.Stderr + t.Cleanup(func() { + zerolog.SetGlobalLevel(originalLevel) + flagError = originalFlagError + os.Args = originalArgs + os.Stderr = originalStderr + }) + + if len(args) > 0 { + os.Args = args + } + flagError = testFlagError + zerolog.SetGlobalLevel(zerolog.TraceLevel) + tempStdErrFileName := filepath.Join(t.TempDir(), uuid.NewString()) + tempStdErr, err := os.Create(tempStdErrFileName) + require.NoError(t, err) + + os.Stderr = tempStdErr + return tempStdErrFileName +} diff --git a/internal/cmd/context.go b/internal/cmd/context.go index 1f34e43d..d5a0eec4 100644 --- a/internal/cmd/context.go +++ b/internal/cmd/context.go @@ -9,6 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" "github.com/authzed/zed/internal/console" "github.com/authzed/zed/internal/printers" "github.com/authzed/zed/internal/storage" @@ -35,7 +36,7 @@ var contextListCmd = &cobra.Command{ Use: "list", Short: "Lists all available contexts", Aliases: []string{"ls"}, - Args: cobra.ExactArgs(0), + Args: commands.ValidationWrapper(cobra.ExactArgs(0)), ValidArgsFunction: cobra.NoFileCompletions, RunE: contextListCmdFunc, } @@ -43,7 +44,7 @@ var contextListCmd = &cobra.Command{ var contextSetCmd = &cobra.Command{ Use: "set ", Short: "Creates or overwrite a context", - Args: cobra.ExactArgs(3), + Args: commands.ValidationWrapper(cobra.ExactArgs(3)), ValidArgsFunction: cobra.NoFileCompletions, RunE: contextSetCmdFunc, } @@ -52,7 +53,7 @@ var contextRemoveCmd = &cobra.Command{ Use: "remove ", Short: "Removes a context", Aliases: []string{"rm"}, - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), ValidArgsFunction: ContextGet, RunE: contextRemoveCmdFunc, } @@ -60,7 +61,7 @@ var contextRemoveCmd = &cobra.Command{ var contextUseCmd = &cobra.Command{ Use: "use ", Short: "Sets a context as the current context", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), ValidArgsFunction: ContextGet, RunE: contextUseCmdFunc, } diff --git a/internal/cmd/import.go b/internal/cmd/import.go index 3cb11f67..687d42ef 100644 --- a/internal/cmd/import.go +++ b/internal/cmd/import.go @@ -16,6 +16,7 @@ import ( "github.com/authzed/spicedb/pkg/validationfile" "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" "github.com/authzed/zed/internal/decode" "github.com/authzed/zed/internal/grpcutil" ) @@ -60,7 +61,7 @@ var importCmd = &cobra.Command{ With schema definition prefix: zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml `, - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), RunE: importCmdFunc, } diff --git a/internal/cmd/preview.go b/internal/cmd/preview.go index 90761734..279e310d 100644 --- a/internal/cmd/preview.go +++ b/internal/cmd/preview.go @@ -41,7 +41,7 @@ var schemaCmd = &cobra.Command{ var schemaCompileCmd = &cobra.Command{ Use: "compile ", - Args: cobra.ExactArgs(1), + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB", Example: ` Write to stdout: diff --git a/internal/cmd/schema.go b/internal/cmd/schema.go index 235e9c23..bbe52f39 100644 --- a/internal/cmd/schema.go +++ b/internal/cmd/schema.go @@ -41,7 +41,7 @@ func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { var schemaWriteCmd = &cobra.Command{ Use: "write ", - Args: cobra.MaximumNArgs(1), + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), Short: "Write a schema file (.zed or stdin) to the current permissions system", ValidArgsFunction: commands.FileExtensionCompletions("zed"), RunE: schemaWriteCmdFunc, @@ -50,7 +50,7 @@ var schemaWriteCmd = &cobra.Command{ var schemaCopyCmd = &cobra.Command{ Use: "copy ", Short: "Copy a schema from one context into another", - Args: cobra.ExactArgs(2), + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), ValidArgsFunction: ContextGet, RunE: schemaCopyCmdFunc, } @@ -58,7 +58,7 @@ var schemaCopyCmd = &cobra.Command{ var schemaDiffCmd = &cobra.Command{ Use: "diff ", Short: "Diff two schema files", - Args: cobra.ExactArgs(2), + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), RunE: schemaDiffCmdFunc, } diff --git a/internal/cmd/validate.go b/internal/cmd/validate.go index 2a61e8e2..8c4aa265 100644 --- a/internal/cmd/validate.go +++ b/internal/cmd/validate.go @@ -71,7 +71,7 @@ var validateCmd = &cobra.Command{ From a devtools instance: zed validate https://localhost:8443/download`, - Args: cobra.MinimumNArgs(1), + Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)), ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"), PreRunE: validatePreRunE, RunE: func(cmd *cobra.Command, filenames []string) error { diff --git a/internal/commands/permission.go b/internal/commands/permission.go index 5ea68a4e..58b022c1 100644 --- a/internal/commands/permission.go +++ b/internal/commands/permission.go @@ -129,14 +129,14 @@ var permissionCmd = &cobra.Command{ var checkBulkCmd = &cobra.Command{ Use: "bulk ...", Short: "Check a permissions in bulk exists for a resource-subject pairs", - Args: cobra.MinimumNArgs(1), + Args: ValidationWrapper(cobra.MinimumNArgs(1)), RunE: checkBulkCmdFunc, } var checkCmd = &cobra.Command{ Use: "check ", Short: "Check that a permission exists for a subject", - Args: cobra.ExactArgs(3), + Args: ValidationWrapper(cobra.ExactArgs(3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectID), RunE: checkCmdFunc, } @@ -144,7 +144,7 @@ var checkCmd = &cobra.Command{ var expandCmd = &cobra.Command{ Use: "expand ", Short: "Expand the structure of a permission", - Args: cobra.ExactArgs(2), + Args: ValidationWrapper(cobra.ExactArgs(2)), ValidArgsFunction: cobra.NoFileCompletions, RunE: expandCmdFunc, } @@ -152,7 +152,7 @@ var expandCmd = &cobra.Command{ var lookupResourcesCmd = &cobra.Command{ Use: "lookup-resources ", Short: "Enumerates resources of a given type for which the subject has permission", - Args: cobra.ExactArgs(3), + Args: ValidationWrapper(cobra.ExactArgs(3)), ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID), RunE: lookupResourcesCmdFunc, } @@ -160,7 +160,7 @@ var lookupResourcesCmd = &cobra.Command{ var lookupCmd = &cobra.Command{ Use: "lookup ", Short: "Enumerates the resources of a given type for which the subject has permission", - Args: cobra.ExactArgs(3), + Args: ValidationWrapper(cobra.ExactArgs(3)), ValidArgsFunction: GetArgs(ResourceType, Permission, SubjectID), RunE: lookupResourcesCmdFunc, Deprecated: "prefer lookup-resources", @@ -170,7 +170,7 @@ var lookupCmd = &cobra.Command{ var lookupSubjectsCmd = &cobra.Command{ Use: "lookup-subjects ", Short: "Enumerates the subjects of a given type for which the subject has permission on the resource", - Args: cobra.ExactArgs(3), + Args: ValidationWrapper(cobra.ExactArgs(3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: lookupSubjectsCmdFunc, } diff --git a/internal/commands/relationship.go b/internal/commands/relationship.go index 730c617e..0471f856 100644 --- a/internal/commands/relationship.go +++ b/internal/commands/relationship.go @@ -70,7 +70,7 @@ var relationshipCmd = &cobra.Command{ var createCmd = &cobra.Command{ Use: "create ", Short: "Create a relationship for a subject", - Args: StdinOrExactArgs(3), + Args: ValidationWrapper(StdinOrExactArgs(3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_CREATE, os.Stdin), } @@ -78,7 +78,7 @@ var createCmd = &cobra.Command{ var touchCmd = &cobra.Command{ Use: "touch ", Short: "Idempotently updates a relationship for a subject", - Args: StdinOrExactArgs(3), + Args: ValidationWrapper(StdinOrExactArgs(3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_TOUCH, os.Stdin), } @@ -86,7 +86,7 @@ var touchCmd = &cobra.Command{ var deleteCmd = &cobra.Command{ Use: "delete ", Short: "Deletes a relationship", - Args: StdinOrExactArgs(3), + Args: ValidationWrapper(StdinOrExactArgs(3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: writeRelationshipCmdFunc(v1.RelationshipUpdate_OPERATION_DELETE, os.Stdin), } @@ -102,7 +102,7 @@ var readCmd = &cobra.Command{ Use: "read ", Short: "Enumerates relationships matching the provided pattern", Long: readCmdHelpLong, - Args: cobra.RangeArgs(1, 3), + Args: ValidationWrapper(cobra.RangeArgs(1, 3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: readRelationships, } @@ -110,7 +110,7 @@ var readCmd = &cobra.Command{ var bulkDeleteCmd = &cobra.Command{ Use: "bulk-delete ", Short: "Deletes relationships matching the provided pattern en masse", - Args: cobra.RangeArgs(1, 3), + Args: ValidationWrapper(cobra.RangeArgs(1, 3)), ValidArgsFunction: GetArgs(ResourceID, Permission, SubjectTypeWithOptionalRelation), RunE: bulkDeleteRelationships, } diff --git a/internal/commands/schema.go b/internal/commands/schema.go index 84d226e7..b56f152e 100644 --- a/internal/commands/schema.go +++ b/internal/commands/schema.go @@ -34,7 +34,7 @@ var ( schemaReadCmd = &cobra.Command{ Use: "read", Short: "Read the schema of a permissions system", - Args: cobra.ExactArgs(0), + Args: ValidationWrapper(cobra.ExactArgs(0)), ValidArgsFunction: cobra.NoFileCompletions, RunE: schemaReadCmdFunc, } diff --git a/internal/commands/util.go b/internal/commands/util.go index a220f40c..6d5b4da0 100644 --- a/internal/commands/util.go +++ b/internal/commands/util.go @@ -2,6 +2,7 @@ package commands import ( "encoding/json" + "errors" "fmt" "strings" @@ -96,3 +97,27 @@ func InjectRequestID(cmd *cobra.Command, _ []string) error { return nil } + +// ValidationError is used to wrap errors that are cobra validation errors. It should be used to +// wrap the Command.PositionalArgs function in order to be able to determine if the error is a validation error. +// This is used to determine if an error should print the usage string. Unfortunately Cobra parameter parsing +// and parameter validation are handled differently, and the latter does not trigger calling Command.FlagErrorFunc +type ValidationError struct { + error +} + +func (ve ValidationError) Is(err error) bool { + var validationError ValidationError + return errors.As(err, &validationError) +} + +// ValidationWrapper is used to be able to determine if an error is a validation error. +func ValidationWrapper(f cobra.PositionalArgs) cobra.PositionalArgs { + return func(cmd *cobra.Command, args []string) error { + if err := f(cmd, args); err != nil { + return ValidationError{error: err} + } + + return nil + } +} diff --git a/internal/commands/util_test.go b/internal/commands/util_test.go new file mode 100644 index 00000000..aa728c82 --- /dev/null +++ b/internal/commands/util_test.go @@ -0,0 +1,44 @@ +package commands + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestValidationWrapper(t *testing.T) { + tests := []struct { + name string + positionalArgs cobra.PositionalArgs + args []string + wantErr bool + }{ + { + name: "valid args", + positionalArgs: cobra.MaximumNArgs(2), + args: []string{"arg1", "arg2"}, + wantErr: false, + }, + { + name: "invalid args", + positionalArgs: cobra.MaximumNArgs(2), + args: []string{"arg1", "arg2", "arg3"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidationWrapper(tt.positionalArgs)(nil, tt.args) + if tt.wantErr { + var validationError ValidationError + require.ErrorAs(t, err, &validationError) + require.NotNil(t, validationError.error) + require.ErrorContains(t, validationError.error, "accepts at most") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/commands/watch.go b/internal/commands/watch.go index bc955063..96b34516 100644 --- a/internal/commands/watch.go +++ b/internal/commands/watch.go @@ -45,7 +45,7 @@ func RegisterWatchRelationshipCmd(parentCmd *cobra.Command) *cobra.Command { var watchCmd = &cobra.Command{ Use: "watch [object_types, ...] [start_cursor]", Short: "Watches the stream of relationship updates from the server", - Args: cobra.RangeArgs(0, 2), + Args: ValidationWrapper(cobra.RangeArgs(0, 2)), RunE: watchCmdFunc, Deprecated: "deprecated; please use `zed watch relationships` instead", } @@ -53,7 +53,7 @@ var watchCmd = &cobra.Command{ var watchRelationshipsCmd = &cobra.Command{ Use: "watch [object_types, ...] [start_cursor]", Short: "Watches the stream of relationship updates from the server", - Args: cobra.RangeArgs(0, 2), + Args: ValidationWrapper(cobra.RangeArgs(0, 2)), RunE: watchCmdFunc, }