Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions internal/cmd/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,29 @@ var (
backupCmd = &cobra.Command{
Use: "backup <filename>",
Short: "Create, restore, and inspect permissions system backups",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
// Create used to be on the root, so add it here for back-compat.
RunE: withErrorHandling(backupCreateCmdFunc),
}

backupCreateCmd = &cobra.Command{
Use: "create <filename>",
Short: "Backup a permission system to a file",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
RunE: withErrorHandling(backupCreateCmdFunc),
}

backupRestoreCmd = &cobra.Command{
Use: "restore <filename>",
Short: "Restore a permission system from a file",
Args: commands.StdinOrExactArgs(1),
Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)),
RunE: backupRestoreCmdFunc,
}

backupParseSchemaCmd = &cobra.Command{
Use: "parse-schema <filename>",
Short: "Extract the schema from a backup file",
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
RunE: func(cmd *cobra.Command, args []string) error {
return backupParseSchemaCmdFunc(cmd, os.Stdout, args)
},
Expand All @@ -84,7 +84,7 @@ var (
backupParseRevisionCmd = &cobra.Command{
Use: "parse-revision <filename>",
Short: "Extract the revision from a backup file",
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
RunE: func(cmd *cobra.Command, args []string) error {
return backupParseRevisionCmdFunc(cmd, os.Stdout, args)
},
Expand All @@ -93,7 +93,7 @@ var (
backupParseRelsCmd = &cobra.Command{
Use: "parse-relationships <filename>",
Short: "Extract the relationships from a backup file",
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
RunE: func(cmd *cobra.Command, args []string) error {
return backupParseRelsCmdFunc(cmd, os.Stdout, args)
},
Expand All @@ -102,7 +102,7 @@ var (
backupRedactCmd = &cobra.Command{
Use: "redact <filename>",
Short: "Redact a backup file to remove sensitive information",
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
RunE: func(cmd *cobra.Command, args []string) error {
return backupRedactCmdFunc(cmd, args)
},
Expand All @@ -129,7 +129,7 @@ func registerBackupCmd(rootCmd *cobra.Command) {
restoreCmd := &cobra.Command{
Use: "restore <filename>",
Short: "Restore a permission system from a backup file",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
RunE: backupRestoreCmdFunc,
Hidden: true,
}
Expand Down
60 changes: 48 additions & 12 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"os"
"os/signal"
"strings"
"syscall"

"github.com/jzelinskie/cobrautil/v2"
Expand Down Expand Up @@ -42,7 +43,15 @@ func init() {
log.Logger = l
}

// This function is utilised to generate docs for zed
var flagError = flagErrorFunc

func flagErrorFunc(cmd *cobra.Command, err error) error {
cmd.Println(err)
cmd.Println(cmd.UsageString())
return errParsing
}

// InitialiseRootCmd This function is utilised to generate docs for zed
func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
rootCmd := &cobra.Command{
Use: "zed",
Expand All @@ -54,12 +63,10 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
commands.InjectRequestID,
),
SilenceErrors: true,
SilenceUsage: false,
SilenceUsage: true,
}
rootCmd.SetFlagErrorFunc(func(cmd *cobra.Command, err error) error {
cmd.Println(err)
cmd.Println(cmd.UsageString())
return errParsing
rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error {
return flagError(command, err)
})

zl.RegisterFlags(rootCmd.PersistentFlags())
Expand Down Expand Up @@ -92,7 +99,7 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
rootCmd.AddCommand(&cobra.Command{
Use: "use <context>",
Short: "Alias for `zed context use`",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
RunE: contextUseCmdFunc,
ValidArgsFunction: ContextGet,
})
Expand All @@ -119,6 +126,12 @@ func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
}

func Run() {
if err := runWithoutExit(); err != nil {
os.Exit(1)
}
}

func runWithoutExit() error {
zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel))

rootCmd := InitialiseRootCmd(zl)
Expand All @@ -141,11 +154,34 @@ func Run() {
}
}()

if err := rootCmd.ExecuteContext(ctx); err != nil {
if !errors.Is(err, errParsing) {
log.Err(err).Msg("terminated with errors")
}
return handleError(rootCmd, rootCmd.ExecuteContext(ctx))
}

os.Exit(1)
func handleError(command *cobra.Command, err error) error {
if err == nil {
return nil
}
// this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately
// parsed. This is necessary to be able to print the proper command-specific usage
var findErr error
var cmdToExecute *cobra.Command
args := os.Args[1:]
if command.TraverseChildren {
cmdToExecute, _, findErr = command.Traverse(args)
} else {
cmdToExecute, _, findErr = command.Find(args)
}
if findErr != nil {
cmdToExecute = command
}

if errors.Is(err, commands.ValidationError{}) {
_ = flagError(cmdToExecute, err)
} else if err != nil && strings.Contains(err.Error(), "unknown command") {
_ = flagError(cmdToExecute, err)
} else if !errors.Is(err, errParsing) {
log.Err(err).Msg("terminated with errors")
}

return err
}
114 changes: 114 additions & 0 deletions internal/cmd/cmd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package cmd

import (
"os"
"path/filepath"
"testing"

"github.com/google/uuid"
"github.com/jzelinskie/cobrautil/v2/cobrazerolog"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)

// note: these tests mess with global variables, so do not run in parallel with other tests.
func TestCommandOutput(t *testing.T) {
cases := []struct {
name string
flagErrorContains string
expectUsageContains string
expectFlagErrorCalled bool
expectStdErrorMsg string
command []string
}{
{
name: "prints usage on invalid command error",
command: []string{"zed", "madeupcommand"},
expectFlagErrorCalled: true,
flagErrorContains: "unknown command",
expectUsageContains: "zed [command]",
},
{
name: "prints usage on invalid flag error",
command: []string{"zed", "version", "--madeupflag"},
expectFlagErrorCalled: true,
flagErrorContains: "unknown flag: --madeupflag",
expectUsageContains: "zed version [flags]",
},
{
name: "prints usage on parameter validation error",
command: []string{"zed", "validate"},
expectFlagErrorCalled: true,
flagErrorContains: "requires at least 1 arg(s), only received 0",
expectUsageContains: "zed validate <validation_file_or_schema_file> [flags]",
},
{
name: "prints correct usage",
command: []string{"zed", "perm", "check"},
expectFlagErrorCalled: true,
flagErrorContains: "accepts 3 arg(s), received 0",
expectUsageContains: "zed permission check <resource:id> <permission> <subject:id>",
},
{
name: "does not print usage on command error",
command: []string{"zed", "validate", uuid.NewString()},
expectFlagErrorCalled: false,
expectStdErrorMsg: "terminated with errors",
},
}

zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel))

rootCmd := InitialiseRootCmd(zl)
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var flagErrorCalled bool
testFlagError := func(cmd *cobra.Command, err error) error {
require.ErrorContains(t, err, tt.flagErrorContains)
require.Contains(t, cmd.UsageString(), tt.expectUsageContains)
flagErrorCalled = true
return errParsing
}
stderrFile := setupOutputForTest(t, testFlagError, tt.command...)

err := handleError(rootCmd, rootCmd.ExecuteContext(t.Context()))
require.Error(t, err)
stdErrBytes, err := os.ReadFile(stderrFile)
require.NoError(t, err)
if tt.expectStdErrorMsg != "" {
require.Contains(t, string(stdErrBytes), tt.expectStdErrorMsg)
} else {
require.Len(t, stdErrBytes, 0)
}
require.Equal(t, tt.expectFlagErrorCalled, flagErrorCalled)
})
}
}

func setupOutputForTest(t *testing.T, testFlagError func(cmd *cobra.Command, err error) error, args ...string) string {
t.Helper()

originalLevel := zerolog.GlobalLevel()
originalFlagError := flagError
originalArgs := os.Args
originalStderr := os.Stderr
t.Cleanup(func() {
zerolog.SetGlobalLevel(originalLevel)
flagError = originalFlagError
os.Args = originalArgs
os.Stderr = originalStderr
})

if len(args) > 0 {
os.Args = args
}
flagError = testFlagError
zerolog.SetGlobalLevel(zerolog.TraceLevel)
tempStdErrFileName := filepath.Join(t.TempDir(), uuid.NewString())
tempStdErr, err := os.Create(tempStdErrFileName)
require.NoError(t, err)

os.Stderr = tempStdErr
return tempStdErrFileName
}
9 changes: 5 additions & 4 deletions internal/cmd/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/spf13/cobra"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/console"
"github.com/authzed/zed/internal/printers"
"github.com/authzed/zed/internal/storage"
Expand All @@ -35,15 +36,15 @@ var contextListCmd = &cobra.Command{
Use: "list",
Short: "Lists all available contexts",
Aliases: []string{"ls"},
Args: cobra.ExactArgs(0),
Args: commands.ValidationWrapper(cobra.ExactArgs(0)),
ValidArgsFunction: cobra.NoFileCompletions,
RunE: contextListCmdFunc,
}

var contextSetCmd = &cobra.Command{
Use: "set <name> <endpoint> <api-token>",
Short: "Creates or overwrite a context",
Args: cobra.ExactArgs(3),
Args: commands.ValidationWrapper(cobra.ExactArgs(3)),
ValidArgsFunction: cobra.NoFileCompletions,
RunE: contextSetCmdFunc,
}
Expand All @@ -52,15 +53,15 @@ var contextRemoveCmd = &cobra.Command{
Use: "remove <system>",
Short: "Removes a context",
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
ValidArgsFunction: ContextGet,
RunE: contextRemoveCmdFunc,
}

var contextUseCmd = &cobra.Command{
Use: "use <system>",
Short: "Sets a context as the current context",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
ValidArgsFunction: ContextGet,
RunE: contextUseCmdFunc,
}
Expand Down
3 changes: 2 additions & 1 deletion internal/cmd/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/authzed/spicedb/pkg/validationfile"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
"github.com/authzed/zed/internal/decode"
"github.com/authzed/zed/internal/grpcutil"
)
Expand Down Expand Up @@ -60,7 +61,7 @@ var importCmd = &cobra.Command{
With schema definition prefix:
zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
`,
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
RunE: importCmdFunc,
}

Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/preview.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ var schemaCmd = &cobra.Command{

var schemaCompileCmd = &cobra.Command{
Use: "compile <file>",
Args: cobra.ExactArgs(1),
Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB",
Example: `
Write to stdout:
Expand Down
6 changes: 3 additions & 3 deletions internal/cmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {

var schemaWriteCmd = &cobra.Command{
Use: "write <file?>",
Args: cobra.MaximumNArgs(1),
Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)),
Short: "Write a schema file (.zed or stdin) to the current permissions system",
ValidArgsFunction: commands.FileExtensionCompletions("zed"),
RunE: schemaWriteCmdFunc,
Expand All @@ -50,15 +50,15 @@ var schemaWriteCmd = &cobra.Command{
var schemaCopyCmd = &cobra.Command{
Use: "copy <src context> <dest context>",
Short: "Copy a schema from one context into another",
Args: cobra.ExactArgs(2),
Args: commands.ValidationWrapper(cobra.ExactArgs(2)),
ValidArgsFunction: ContextGet,
RunE: schemaCopyCmdFunc,
}

var schemaDiffCmd = &cobra.Command{
Use: "diff <before file> <after file>",
Short: "Diff two schema files",
Args: cobra.ExactArgs(2),
Args: commands.ValidationWrapper(cobra.ExactArgs(2)),
RunE: schemaDiffCmdFunc,
}

Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ var validateCmd = &cobra.Command{

From a devtools instance:
zed validate https://localhost:8443/download`,
Args: cobra.MinimumNArgs(1),
Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)),
ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"),
PreRunE: validatePreRunE,
RunE: func(cmd *cobra.Command, filenames []string) error {
Expand Down
Loading
Loading