From e70919709c3db0459465c55f53e600b0dfa3e93f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Mon, 21 Apr 2025 19:10:26 +0100 Subject: [PATCH 1/4] make backup create command file name optional derives the file name from the zed context name --- internal/cmd/backup.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index fbf20877..35845f3b 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "path/filepath" "regexp" "strconv" "strings" @@ -37,7 +38,7 @@ var ( backupCmd = &cobra.Command{ Use: "backup ", Short: "Create, restore, and inspect permissions system backups", - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), // Create used to be on the root, so add it here for back-compat. RunE: backupCreateCmdFunc, } @@ -45,7 +46,7 @@ var ( backupCreateCmd = &cobra.Command{ Use: "create ", Short: "Backup a permission system to a file", - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), RunE: backupCreateCmdFunc, } @@ -238,7 +239,24 @@ func hasRelPrefix(rel *v1.Relationship, prefix string) bool { } func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { - f, err := createBackupFile(args[0]) + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + panic(err) + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + if len(args) > 0 { + backupFileName = args[0] + } + + f, err := createBackupFile(backupFileName) if err != nil { return err } From 4f0f0eb556f65d1ba6d755a048fd0e53ff358ec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Mon, 21 Apr 2025 19:25:47 +0100 Subject: [PATCH 2/4] create backup: add support for page limit adds a flag that allows the user defining the number of relationships they'd like to be sent per bulk export page --- internal/cmd/backup.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index 35845f3b..c7b95a06 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -101,6 +101,8 @@ func registerBackupCmd(rootCmd *cobra.Command) { backupCmd.AddCommand(backupCreateCmd) registerBackupCreateFlags(backupCreateCmd) + backupCreateCmd.Flags().Uint32("page-limit", 0, "include only schema and relationships with a given prefix") + backupCmd.AddCommand(backupRestoreCmd) registerBackupRestoreFlags(backupRestoreCmd) @@ -256,6 +258,8 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { backupFileName = args[0] } + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + f, err := createBackupFile(backupFileName) if err != nil { return err @@ -300,6 +304,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) relationshipStream, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, Consistency: &v1.Consistency{ Requirement: &v1.Consistency_AtExactSnapshot{ AtExactSnapshot: schemaResp.ReadAt, From b13830c490f4282a66c691e842d5ecc71bd65324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Tue, 22 Apr 2025 18:29:10 +0100 Subject: [PATCH 3/4] introduce support for resuming a backup this commit changes the backup create command so that a canceled backup can be resumed. A new marker file is added with the last written bulk export cursor. This was not added to the OCF file because OCF container is not meant for in-place updates, but streaming append operations. The marker is updated every time a bulk export page is successfully written. For convenience the command now also supports creating a backup without a file name: in that case it will derive the backup file name from the zed context name. The new logic detects several scenarios like: - a backup exists, but no marker exists (meaning it completed) - backup does not exist, but marker is left behind (it gets truncated) The marker file will be removed only when the backup completes. This new logic does not guarantee relationships may not get duplicated. When terminated gracefully the system should write the last page received and close and flush the files. But if the process was abruptly terminated (e.g. SIGKILL) it could lead to relationships being written to the OCF file, but the marker not being updated. --- go.mod | 2 +- internal/client/client.go | 4 +- internal/cmd/backup.go | 316 +++++++++++++++++++------ internal/cmd/backup_test.go | 239 ++++++++++++++++--- internal/cmd/import_test.go | 4 +- internal/cmd/restorer_test.go | 2 +- internal/cmd/schema_test.go | 3 +- internal/commands/permission_test.go | 2 +- internal/commands/relationship_test.go | 6 +- internal/storage/secrets.go | 2 +- internal/testing/test_helpers.go | 2 +- pkg/backupformat/encoder.go | 14 ++ 12 files changed, 480 insertions(+), 116 deletions(-) diff --git a/go.mod b/go.mod index df15f579..1644a6ab 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/authzed/zed -go 1.23.8 +go 1.24 toolchain go1.24.1 diff --git a/internal/client/client.go b/internal/client/client.go index 5a1c0b02..a90e0745 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -162,7 +162,9 @@ func tokenFromCli(cmd *cobra.Command) (storage.Token, error) { } // DefaultStorage returns the default configured config store and secret store. -func DefaultStorage() (storage.ConfigStore, storage.SecretStore) { +var DefaultStorage = defaultStorage + +func defaultStorage() (storage.ConfigStore, storage.SecretStore) { var home string if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" { home = filepath.Join(xdg, "zed") diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index c7b95a06..fef54f77 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -34,6 +34,11 @@ import ( "github.com/authzed/zed/pkg/backupformat" ) +const ( + returnIfExists = true + doNotReturnIfExists = false +) + var ( backupCmd = &cobra.Command{ Use: "backup ", @@ -101,7 +106,7 @@ func registerBackupCmd(rootCmd *cobra.Command) { backupCmd.AddCommand(backupCreateCmd) registerBackupCreateFlags(backupCreateCmd) - backupCreateCmd.Flags().Uint32("page-limit", 0, "include only schema and relationships with a given prefix") + backupCreateCmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") backupCmd.AddCommand(backupRestoreCmd) registerBackupRestoreFlags(backupRestoreCmd) @@ -147,24 +152,33 @@ func registerBackupCreateFlags(cmd *cobra.Command) { cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") } -func createBackupFile(filename string) (*os.File, error) { +func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { if filename == "-" { log.Trace().Str("filename", "- (stdout)").Send() - return os.Stdout, nil + return os.Stdout, false, nil } log.Trace().Str("filename", filename).Send() if _, err := os.Stat(filename); err == nil { - return nil, fmt.Errorf("backup file already exists: %s", filename) + if !returnIfExists { + return nil, false, fmt.Errorf("backup file already exists: %s", filename) + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to open existing backup file: %w", err) + } + + return f, true, nil } - f, err := os.Create(filename) + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644) if err != nil { - return nil, fmt.Errorf("unable to create backup file: %w", err) + return nil, false, fmt.Errorf("unable to create backup file: %w", err) } - return f, nil + return f, false, nil } var ( @@ -241,91 +255,114 @@ func hasRelPrefix(rel *v1.Relationship, prefix string) bool { } func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { - configStore, secretStore := client.DefaultStorage() - token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) - if err != nil { - return fmt.Errorf("failed to determine current zed context: %w", err) - } + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") - ex, err := os.Executable() + backupFileName, err := computeBackupFileName(cmd, args) if err != nil { - panic(err) + return err } - exPath := filepath.Dir(ex) - backupFileName := filepath.Join(exPath, token.Name+".zedbackup") - if len(args) > 0 { - backupFileName = args[0] + backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists) + if err != nil { + return err } - pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + defer func(e *error) { + *e = errors.Join(*e, backupFile.Sync()) + *e = errors.Join(*e, backupFile.Close()) + }(&err) - f, err := createBackupFile(backupFileName) + // the goal of this file is to keep the bulk export cursor in case the process is terminated + // and we need to resume from where we left off. OCF does not support in-place record updates. + progressFile, cursor, err := openProgressFile(backupFileName, backupExists) if err != nil { return err } - defer func(e *error) { *e = errors.Join(*e, f.Close()) }(&err) - defer func(e *error) { *e = errors.Join(*e, f.Sync()) }(&err) + var backupCompleted bool + defer func(e *error) { + *e = errors.Join(*e, progressFile.Sync()) + *e = errors.Join(*e, progressFile.Close()) + + if backupCompleted { + if err := os.Remove(progressFile.Name()); err != nil { + log.Warn(). + Str("progress-file", progressFile.Name()). + Msg("failed to remove progress file, consider removing it manually") + } + } + }(&err) c, err := client.NewClient(cmd) if err != nil { return fmt.Errorf("unable to initialize client: %w", err) } - ctx := cmd.Context() - schemaResp, err := c.ReadSchema(ctx, &v1.ReadSchemaRequest{}) - if err != nil { - return fmt.Errorf("error reading schema: %w", addSizeErrInfo(err)) - } else if schemaResp.ReadAt == nil { - return fmt.Errorf("`backup` is not supported on this version of SpiceDB") - } - schema := schemaResp.SchemaText - - // Remove any invalid relations generated from old, backwards-incompat - // Serverless permission systems. - if cobrautil.MustGetBool(cmd, "rewrite-legacy") { - schema = rewriteLegacy(schema) - } - - // Skip any definitions without the provided prefix - prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") - if prefixFilter != "" { - schema, err = filterSchemaDefs(schema, prefixFilter) + var zedToken *v1.ZedToken + var encoder *backupformat.Encoder + if backupExists { + encoder, err = backupformat.NewEncoderForExisting(backupFile) + if err != nil { + return fmt.Errorf("error creating backup file encoder: %w", err) + } + } else { + encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) if err != nil { return err } } - encoder, err := backupformat.NewEncoder(f, schema, schemaResp.ReadAt) - if err != nil { - return fmt.Errorf("error creating backup file encoder: %w", err) - } defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + req := &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } - relationshipStream, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ - OptionalLimit: pageLimit, - Consistency: &v1.Consistency{ + // if a cursor is present, zedtoken is not needed (it is already in the cursor) + if zedToken != nil { + req.Consistency = &v1.Consistency{ Requirement: &v1.Consistency_AtExactSnapshot{ - AtExactSnapshot: schemaResp.ReadAt, + AtExactSnapshot: zedToken, }, - }, - }) + } + } + + ctx := cmd.Context() + relationshipStream, err := c.BulkExportRelationships(ctx, req) if err != nil { return fmt.Errorf("error exporting relationships: %w", addSizeErrInfo(err)) } relationshipReadStart := time.Now() - + tick := time.Tick(5 * time.Second) bar := console.CreateProgressBar("processing backup") - var relsEncoded, relsProcessed uint + var relsFilteredOut, relsProcessed uint64 for { if err := ctx.Err(); err != nil { + _ = bar.Finish() + if isCanceled(err) { + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Msg("backup canceled - resume by restarting the backup command") + return context.Canceled + } + return fmt.Errorf("aborted backup: %w", err) } relsResp, err := relationshipStream.Recv() if err != nil { + _ = bar.Finish() + if isCanceled(err) { + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Msg("backup canceled - resume by restarting the backup command") + return context.Canceled + } + if !errors.Is(err, io.EOF) { return fmt.Errorf("error receiving relationships: %w", addSizeErrInfo(err)) } @@ -337,37 +374,178 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { if err := encoder.Append(rel); err != nil { return fmt.Errorf("error storing relationship: %w", err) } - relsEncoded++ - - if relsEncoded%100_000 == 0 && !isatty.IsTerminal(os.Stderr.Fd()) { - log.Trace(). - Uint("encoded", relsEncoded). - Uint("processed", relsProcessed). - Msg("backup progress") - } + } else { + relsFilteredOut++ } + relsProcessed++ if err := bar.Add(1); err != nil { return fmt.Errorf("error incrementing progress bar: %w", err) } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Msg("backup progress") + default: + } + } } - } - totalTime := time.Since(relationshipReadStart) - if err := bar.Finish(); err != nil { - return fmt.Errorf("error finalizing progress bar: %w", err) + if err := writeProgress(progressFile, relsResp); err != nil { + return err + } } + totalTime := time.Since(relationshipReadStart) + _ = bar.Finish() + log.Info(). - Uint("encoded", relsEncoded). - Uint("processed", relsProcessed). - Uint64("perSecond", perSec(uint64(relsProcessed), totalTime)). + Uint64("processed", relsProcessed). + Uint64("filtered", relsFilteredOut). + Uint64("throughput", perSec(relsProcessed, totalTime)). Stringer("duration", totalTime). Msg("finished backup") + backupCompleted = true return nil } +// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup +// must be taken. +func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + + schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) + if err != nil { + return nil, nil, fmt.Errorf("error reading schema: %w", addSizeErrInfo(err)) + } + if schemaResp.ReadAt == nil { + return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") + } + schema := schemaResp.SchemaText + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return nil, nil, err + } + } + + zedToken := schemaResp.ReadAt + + encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken) + if err != nil { + return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err) + } + + return encoder, zedToken, nil +} + +func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error { + err := progressFile.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + _, err = progressFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token) + if err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } + + return nil +} + +// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates +// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error. +// +// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume +// backups in case of failure. +func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) { + var cursor *v1.Cursor + progressFileName := toLockFileName(backupFileName) + var progressFile *os.File + // if a backup existed + var fileMode int + readCursor, err := os.ReadFile(progressFileName) + if os.IsNotExist(err) && backupAlreadyExisted { + return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName) + } else if err == nil && backupAlreadyExisted { + cursor = &v1.Cursor{ + Token: string(readCursor), + } + + // if backup existed and there is a progress marker, the latter should not be truncated to make sure the + // cursor stays around in case of a failure before we even start ingesting from bulk export + fileMode = os.O_WRONLY | os.O_CREATE + log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume") + } else { + // if a backup did not exist, make sure to truncate the progress file + fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC + } + + progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644) + if err != nil { + return nil, nil, err + } + + return progressFile, cursor, nil +} + +func toLockFileName(backupFileName string) string { + return backupFileName + ".lock" +} + +// computeBackupFileName computes the backup file name based. +// If no file name is provided, it derives a backup on the current context +func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) { + if len(args) > 0 { + return args[0], nil + } + + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return "", fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + return "", err + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + + return backupFileName, nil +} + +func isCanceled(err error) bool { + if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled { + return true + } + + return errors.Is(err, context.Canceled) +} + func openRestoreFile(filename string) (*os.File, int64, error) { if filename == "" { log.Trace().Str("filename", "(stdin)").Send() @@ -508,7 +686,7 @@ func backupRedactCmdFunc(cmd *cobra.Command, args []string) error { defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) filename := args[0] + ".redacted" - writer, err := createBackupFile(filename) + writer, _, err := createBackupFile(filename, doNotReturnIfExists) if err != nil { return err } diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index 3ac82f30..81be67f2 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -1,8 +1,9 @@ package cmd import ( - "context" + "encoding/json" "errors" + "fmt" "os" "path/filepath" "strings" @@ -15,9 +16,11 @@ import ( "google.golang.org/grpc/status" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/storage" zedtesting "github.com/authzed/zed/internal/testing" ) @@ -270,16 +273,15 @@ func TestBackupParseSchemaCmdFunc(t *testing.T) { func TestBackupCreateCmdFunc(t *testing.T) { cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, zedtesting.StringFlag{FlagName: "prefix-filter"}, - zedtesting.BoolFlag{FlagName: "rewrite-legacy"}) - f := filepath.Join(os.TempDir(), uuid.NewString()) - _, err := os.Stat(f) - require.Error(t, err) - defer func() { - _ = os.Remove(f) - }() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + zedtesting.BoolFlag{FlagName: "rewrite-legacy"}, + zedtesting.UintFlag32{FlagName: "page-limit"}, + zedtesting.StringFlag{FlagName: "token"}, + zedtesting.StringFlag{FlagName: "certificate-path"}, + zedtesting.StringFlag{FlagName: "endpoint"}, + zedtesting.BoolFlag{FlagName: "insecure"}, + zedtesting.BoolFlag{FlagName: "no-verify-ca"}) + + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(t, srv.Run(ctx)) @@ -288,9 +290,9 @@ func TestBackupCreateCmdFunc(t *testing.T) { require.NoError(t, err) originalClient := client.NewClient - defer func() { + t.Cleanup(func() { client.NewClient = originalClient - }() + }) client.NewClient = zedtesting.ClientFromConn(conn) @@ -300,32 +302,204 @@ func TestBackupCreateCmdFunc(t *testing.T) { _, err = c.WriteSchema(ctx, &v1.WriteSchemaRequest{Schema: testSchema}) require.NoError(t, err) - testRel := "test/resource:1#reader@test/user:1" - resp, err := c.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{ - Updates: []*v1.RelationshipUpdate{ - { - Operation: v1.RelationshipUpdate_OPERATION_TOUCH, - Relationship: tuple.MustParseV1Rel(testRel), + update := &v1.WriteRelationshipsRequest{} + testRel := "test/resource:1#reader@test/user:%d" + expectedRels := make([]string, 0, 100) + for i := range 100 { + relString := fmt.Sprintf(testRel, i) + update.Updates = append(update.Updates, &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(relString), + }) + expectedRels = append(expectedRels, relString) + } + resp, err := c.WriteRelationships(ctx, update) + require.NoError(t, err) + + t.Run("successful backup", func(t *testing.T) { + f := filepath.Join(t.TempDir(), uuid.NewString()) + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + + validateBackup(t, f, testSchema, resp.WrittenAt, expectedRels) + // validate progress file is deleted after successful backup + require.NoFileExists(t, toLockFileName(f)) + }) + + t.Run("fails if backup without progress file exists", func(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), uuid.NewString()) + _, err := os.Create(tempFile) + require.NoError(t, err) + + err = backupCreateCmdFunc(cmd, []string{tempFile}) + require.ErrorContains(t, err, "already exists") + }) + + t.Run("derives backup file name from context if not provided", func(t *testing.T) { + ds := client.DefaultStorage + t.Cleanup(func() { + client.DefaultStorage = ds + }) + + cfg := storage.Config{CurrentToken: "my-test"} + cfgBytes, err := json.Marshal(cfg) + require.NoError(t, err) + + testContextPath := filepath.Join(t.TempDir(), "config.json") + err = os.WriteFile(testContextPath, cfgBytes, 0o600) + require.NoError(t, err) + + name := uuid.NewString() + client.DefaultStorage = func() (storage.ConfigStore, storage.SecretStore) { + return &testConfigStore{currentToken: name}, + &testSecretStore{token: storage.Token{Name: name}} + } + err = backupCreateCmdFunc(cmd, nil) + require.NoError(t, err) + + currentPath, err := os.Executable() + require.NoError(t, err) + exPath := filepath.Dir(currentPath) + expectedBackupFile := filepath.Join(exPath, name+".zedbackup") + require.FileExists(t, expectedBackupFile) + validateBackup(t, expectedBackupFile, testSchema, resp.WrittenAt, expectedRels) + }) + + t.Run("truncates progress marker if it existed but backup did not", func(t *testing.T) { + streamClient, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: resp.WrittenAt, + }, }, - }, + OptionalLimit: 1, + }) + require.NoError(t, err) + + streamResp, err := streamClient.Recv() + require.NoError(t, err) + _ = streamClient.CloseSend() + + f := filepath.Join(t.TempDir(), uuid.NewString()) + lockFileName := toLockFileName(f) + err = os.WriteFile(lockFileName, []byte(streamResp.AfterResultCursor.Token), 0o600) + require.NoError(t, err) + + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + + // we know it did its work because it imported the 100 relationships regardless of the progress file + validateBackup(t, f, testSchema, resp.WrittenAt, expectedRels) + require.NoFileExists(t, toLockFileName(f)) }) - require.NoError(t, err) - err = backupCreateCmdFunc(cmd, []string{f}) - require.NoError(t, err) + t.Run("resumes backup if marker file exists", func(t *testing.T) { + streamClient, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: resp.WrittenAt, + }, + }, + OptionalLimit: 90, + }) + require.NoError(t, err) + + streamResp, err := streamClient.Recv() + require.NoError(t, err) + _ = streamClient.CloseSend() - d, closer, err := decoderFromArgs(f) + f := filepath.Join(t.TempDir(), uuid.NewString()) + + // do an initial backup to have the OCF metadata in place, it will also import the 100 rels + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + require.FileExists(t, f) + + lockFileName := toLockFileName(f) + err = os.WriteFile(lockFileName, []byte(streamResp.AfterResultCursor.Token), 0o600) + require.NoError(t, err) + + // run backup again, this time with an existing backup file and progress file + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + require.NoFileExists(t, toLockFileName(f)) + + // we know it did its work because we created a progress file at relationship 90, so we will get + // a backup with 100 rels from the original import + the last 10 rels repeated again (110 in total) + validationFunc := func(t *testing.T, expected, received []string) { + require.Len(t, received, 110) + receivedSet := mapz.NewSet(received...) + expectedSet := mapz.NewSet(expected...) + + require.Equal(t, 100, receivedSet.Len()) + require.True(t, receivedSet.Equal(expectedSet)) + + for i, s := range received[100:] { + require.Equal(t, fmt.Sprintf(testRel, i+90), s) + } + } + validateBackupWithFunc(t, f, testSchema, resp.WrittenAt, expectedRels, validationFunc) + }) +} + +type testConfigStore struct { + storage.ConfigStore + currentToken string +} + +func (tcs testConfigStore) Get() (storage.Config, error) { + return storage.Config{CurrentToken: tcs.currentToken}, nil +} + +func (tcs testConfigStore) Exists() (bool, error) { + return true, nil +} + +type testSecretStore struct { + storage.SecretStore + token storage.Token +} + +func (tss testSecretStore) Get() (storage.Secrets, error) { + return storage.Secrets{Tokens: []storage.Token{tss.token}}, nil +} + +func validateBackup(t *testing.T, backupFileName string, schema string, token *v1.ZedToken, expected []string) { + t.Helper() + + f := func(t *testing.T, expected, received []string) { + require.ElementsMatch(t, expected, received) + } + + validateBackupWithFunc(t, backupFileName, schema, token, expected, f) +} + +func validateBackupWithFunc(t *testing.T, backupFileName string, schema string, token *v1.ZedToken, expected []string, + validateRels func(t *testing.T, expected, received []string), +) { + t.Helper() + + d, closer, err := decoderFromArgs(backupFileName) require.NoError(t, err) - defer func() { + t.Cleanup(func() { _ = d.Close() _ = closer.Close() - }() + }) - require.Equal(t, testSchema, d.Schema()) - rel, err := d.Next() - require.NoError(t, err) - require.Equal(t, testRel, tuple.MustV1StringRelationship(rel)) - require.Equal(t, resp.WrittenAt.Token, d.ZedToken().Token) + require.Equal(t, schema, d.Schema()) + require.Equal(t, token.Token, d.ZedToken().Token) + var received []string + for { + rel, err := d.Next() + if rel == nil { + break + } + + require.NoError(t, err) + received = append(received, tuple.MustV1StringRelationship(rel)) + } + + validateRels(t, expected, received) } func TestBackupRestoreCmdFunc(t *testing.T) { @@ -340,8 +514,7 @@ func TestBackupRestoreCmdFunc(t *testing.T) { ) backupName := createTestBackup(t, testSchema, testRelationships) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(t, srv.Run(ctx)) diff --git a/internal/cmd/import_test.go b/internal/cmd/import_test.go index f23488e3..071e5027 100644 --- a/internal/cmd/import_test.go +++ b/internal/cmd/import_test.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "path/filepath" "testing" @@ -27,8 +26,7 @@ func TestImportCmdHappyPath(t *testing.T) { f := filepath.Join("import-test", "happy-path-validation-file.yaml") // Set up client - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(srv.Run(ctx)) diff --git a/internal/cmd/restorer_test.go b/internal/cmd/restorer_test.go index 255c25cb..fd007376 100644 --- a/internal/cmd/restorer_test.go +++ b/internal/cmd/restorer_test.go @@ -157,7 +157,7 @@ func TestRestorer(t *testing.T) { } r := newRestorer(testSchema, d, c, tt.prefixFilter, tt.batchSize, tt.batchesPerTransaction, tt.conflictStrategy, tt.disableRetryErrors, 0*time.Second) - err = r.restoreFromDecoder(context.Background()) + err = r.restoreFromDecoder(t.Context()) if expectsError != nil || (expectedConflicts > 0 && tt.conflictStrategy == Fail) { require.ErrorIs(err, expectsError) return diff --git a/internal/cmd/schema_test.go b/internal/cmd/schema_test.go index 2a5908ec..e118f65b 100644 --- a/internal/cmd/schema_test.go +++ b/internal/cmd/schema_test.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -59,7 +58,7 @@ func TestDeterminePrefixForSchema(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - found, err := determinePrefixForSchema(context.Background(), test.specifiedPrefix, nil, &test.existingSchema) + found, err := determinePrefixForSchema(t.Context(), test.specifiedPrefix, nil, &test.existingSchema) require.NoError(t, err) require.Equal(t, test.expectedPrefix, found) }) diff --git a/internal/commands/permission_test.go b/internal/commands/permission_test.go index c3838c87..55021989 100644 --- a/internal/commands/permission_test.go +++ b/internal/commands/permission_test.go @@ -103,7 +103,7 @@ func TestCheckErrorWithInvalidDebugInformation(t *testing.T) { } func TestLookupResourcesCommand(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { diff --git a/internal/commands/relationship_test.go b/internal/commands/relationship_test.go index afcbbebe..70cc25cd 100644 --- a/internal/commands/relationship_test.go +++ b/internal/commands/relationship_test.go @@ -670,7 +670,7 @@ func (m *mockClient) WriteRelationships(_ context.Context, in *v1.WriteRelations } func TestBulkDeleteForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { @@ -720,7 +720,7 @@ func TestBulkDeleteForcing(t *testing.T) { } func TestBulkDeleteManyForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { @@ -762,7 +762,7 @@ func TestBulkDeleteManyForcing(t *testing.T) { } func TestBulkDeleteNotForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { diff --git a/internal/storage/secrets.go b/internal/storage/secrets.go index c76620df..3436f6be 100644 --- a/internal/storage/secrets.go +++ b/internal/storage/secrets.go @@ -69,7 +69,7 @@ type SecretStore interface { Put(s Secrets) error } -// Returns an empty token if no token exists. +// GetTokenIfExists returns an empty token if no token exists. func GetTokenIfExists(name string, ss SecretStore) (Token, error) { secrets, err := ss.Get() if err != nil { diff --git a/internal/testing/test_helpers.go b/internal/testing/test_helpers.go index 046563a2..d921fdbb 100644 --- a/internal/testing/test_helpers.go +++ b/internal/testing/test_helpers.go @@ -127,6 +127,6 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co } } - c.SetContext(context.Background()) + c.SetContext(t.Context()) return &c } diff --git a/pkg/backupformat/encoder.go b/pkg/backupformat/encoder.go index a8f06b83..f591ee72 100644 --- a/pkg/backupformat/encoder.go +++ b/pkg/backupformat/encoder.go @@ -11,6 +11,20 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" ) +func NewEncoderForExisting(w io.Writer) (*Encoder, error) { + avroSchema, err := avroSchemaV1() + if err != nil { + return nil, fmt.Errorf("unable to create avro schema: %w", err) + } + + enc, err := ocf.NewEncoder(avroSchema, w, ocf.WithCodec(ocf.Snappy)) + if err != nil { + return nil, fmt.Errorf("unable to create encoder: %w", err) + } + + return &Encoder{enc}, nil +} + func NewEncoder(w io.Writer, schema string, token *v1.ZedToken) (*Encoder, error) { avroSchema, err := avroSchemaV1() if err != nil { From cea939a4f104ec83c9d8488586e570b4bc696b10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?= Date: Wed, 23 Apr 2025 12:15:17 +0100 Subject: [PATCH 4/4] various quality of life improvements from last commit - print backup stats always, regardless termination reason - improve handling of ResourceExhausted error, instead of fixing all exit paths, wrap the command function so the error is always evaluated. - fixes malformed error message due to error variable overlap - made resilient to existing backup with empty marker file --- internal/cmd/backup.go | 69 ++++++++++++++++++++------------- internal/grpcutil/batch_test.go | 4 +- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index fef54f77..5e32de44 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -39,6 +39,16 @@ const ( doNotReturnIfExists = false ) +// cobraRunEFunc is the signature of a cobra.Command.RunE function. +type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) + +// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic. +func withErrorHandling(f cobraRunEFunc) cobraRunEFunc { + return func(cmd *cobra.Command, args []string) (err error) { + return addSizeErrInfo(f(cmd, args)) + } +} + var ( backupCmd = &cobra.Command{ Use: "backup ", @@ -52,7 +62,7 @@ var ( Use: "create ", Short: "Backup a permission system to a file", Args: cobra.MaximumNArgs(1), - RunE: backupCreateCmdFunc, + RunE: withErrorHandling(backupCreateCmdFunc), } backupRestoreCmd = &cobra.Command{ @@ -314,6 +324,11 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + + if zedToken == nil && cursor == nil { + return errors.New("malformed existing backup, consider recreating it") + } + req := &v1.BulkExportRelationshipsRequest{ OptionalLimit: pageLimit, OptionalCursor: cursor, @@ -331,21 +346,33 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { ctx := cmd.Context() relationshipStream, err := c.BulkExportRelationships(ctx, req) if err != nil { - return fmt.Errorf("error exporting relationships: %w", addSizeErrInfo(err)) + return fmt.Errorf("error exporting relationships: %w", err) } relationshipReadStart := time.Now() tick := time.Tick(5 * time.Second) bar := console.CreateProgressBar("processing backup") var relsFilteredOut, relsProcessed uint64 + defer func() { + _ = bar.Finish() + + evt := log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)) + if isCanceled(err) { + evt.Msg("backup canceled - resume by restarting the backup command") + } else if err != nil { + evt.Msg("backup failed") + } else { + evt.Msg("finished backup") + } + }() + for { if err := ctx.Err(); err != nil { - _ = bar.Finish() if isCanceled(err) { - log.Info(). - Uint64("filtered", relsFilteredOut). - Uint64("processed", relsProcessed). - Msg("backup canceled - resume by restarting the backup command") return context.Canceled } @@ -354,17 +381,12 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { relsResp, err := relationshipStream.Recv() if err != nil { - _ = bar.Finish() if isCanceled(err) { - log.Info(). - Uint64("filtered", relsFilteredOut). - Uint64("processed", relsProcessed). - Msg("backup canceled - resume by restarting the backup command") return context.Canceled } if !errors.Is(err, io.EOF) { - return fmt.Errorf("error receiving relationships: %w", addSizeErrInfo(err)) + return fmt.Errorf("error receiving relationships: %w", err) } break } @@ -391,6 +413,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { Uint64("filtered", relsFilteredOut). Uint64("processed", relsProcessed). Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). Msg("backup progress") default: } @@ -402,16 +425,6 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } } - totalTime := time.Since(relationshipReadStart) - _ = bar.Finish() - - log.Info(). - Uint64("processed", relsProcessed). - Uint64("filtered", relsFilteredOut). - Uint64("throughput", perSec(relsProcessed, totalTime)). - Stringer("duration", totalTime). - Msg("finished backup") - backupCompleted = true return nil } @@ -423,7 +436,7 @@ func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.Fil schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) if err != nil { - return nil, nil, fmt.Errorf("error reading schema: %w", addSizeErrInfo(err)) + return nil, nil, fmt.Errorf("error reading schema: %w", err) } if schemaResp.ReadAt == nil { return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") @@ -486,9 +499,9 @@ func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.Fil // if a backup existed var fileMode int readCursor, err := os.ReadFile(progressFileName) - if os.IsNotExist(err) && backupAlreadyExisted { + if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) { return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName) - } else if err == nil && backupAlreadyExisted { + } else if backupAlreadyExisted && err == nil { cursor = &v1.Cursor{ Token: string(readCursor), } @@ -847,8 +860,8 @@ func addSizeErrInfo(err error) error { return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) } - necessaryByteCount, err := strconv.Atoi(matches[1]) - if err != nil { + necessaryByteCount, atoiErr := strconv.Atoi(matches[1]) + if atoiErr != nil { return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) } diff --git a/internal/grpcutil/batch_test.go b/internal/grpcutil/batch_test.go index 98e48ae4..b89ccc1b 100644 --- a/internal/grpcutil/batch_test.go +++ b/internal/grpcutil/batch_test.go @@ -67,7 +67,7 @@ func TestConcurrentBatchOrdering(t *testing.T) { return nil } - err := ConcurrentBatch(context.Background(), len(tt.items), batchSize, workers, fn) + err := ConcurrentBatch(t.Context(), len(tt.items), batchSize, workers, fn) require.NoError(err) got := make([]batch, len(gotCh)) @@ -133,7 +133,7 @@ func TestConcurrentBatch(t *testing.T) { atomic.AddInt64(&calls, 1) return nil } - err := ConcurrentBatch(context.Background(), len(tt.items), tt.batchSize, tt.workers, fn) + err := ConcurrentBatch(t.Context(), len(tt.items), tt.batchSize, tt.workers, fn) require.NoError(err) require.Equal(tt.wantCalls, int(calls))