diff --git a/go.mod b/go.mod index df15f579..1644a6ab 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/authzed/zed -go 1.23.8 +go 1.24 toolchain go1.24.1 diff --git a/internal/client/client.go b/internal/client/client.go index 5a1c0b02..a90e0745 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -162,7 +162,9 @@ func tokenFromCli(cmd *cobra.Command) (storage.Token, error) { } // DefaultStorage returns the default configured config store and secret store. -func DefaultStorage() (storage.ConfigStore, storage.SecretStore) { +var DefaultStorage = defaultStorage + +func defaultStorage() (storage.ConfigStore, storage.SecretStore) { var home string if xdg := os.Getenv("XDG_CONFIG_HOME"); xdg != "" { home = filepath.Join(xdg, "zed") diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index fbf20877..5e32de44 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "path/filepath" "regexp" "strconv" "strings" @@ -33,11 +34,26 @@ import ( "github.com/authzed/zed/pkg/backupformat" ) +const ( + returnIfExists = true + doNotReturnIfExists = false +) + +// cobraRunEFunc is the signature of a cobra.Command.RunE function. +type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) + +// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic. +func withErrorHandling(f cobraRunEFunc) cobraRunEFunc { + return func(cmd *cobra.Command, args []string) (err error) { + return addSizeErrInfo(f(cmd, args)) + } +} + var ( backupCmd = &cobra.Command{ Use: "backup ", Short: "Create, restore, and inspect permissions system backups", - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), // Create used to be on the root, so add it here for back-compat. 
RunE: backupCreateCmdFunc, } @@ -45,8 +61,8 @@ var ( backupCreateCmd = &cobra.Command{ Use: "create ", Short: "Backup a permission system to a file", - Args: cobra.ExactArgs(1), - RunE: backupCreateCmdFunc, + Args: cobra.MaximumNArgs(1), + RunE: withErrorHandling(backupCreateCmdFunc), } backupRestoreCmd = &cobra.Command{ @@ -100,6 +116,8 @@ func registerBackupCmd(rootCmd *cobra.Command) { backupCmd.AddCommand(backupCreateCmd) registerBackupCreateFlags(backupCreateCmd) + backupCreateCmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") + backupCmd.AddCommand(backupRestoreCmd) registerBackupRestoreFlags(backupRestoreCmd) @@ -144,24 +162,33 @@ func registerBackupCreateFlags(cmd *cobra.Command) { cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") } -func createBackupFile(filename string) (*os.File, error) { +func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { if filename == "-" { log.Trace().Str("filename", "- (stdout)").Send() - return os.Stdout, nil + return os.Stdout, false, nil } log.Trace().Str("filename", filename).Send() if _, err := os.Stat(filename); err == nil { - return nil, fmt.Errorf("backup file already exists: %s", filename) + if !returnIfExists { + return nil, false, fmt.Errorf("backup file already exists: %s", filename) + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to open existing backup file: %w", err) + } + + return f, true, nil } - f, err := os.Create(filename) + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644) if err != nil { - return nil, fmt.Errorf("unable to create backup file: %w", err) + return nil, false, fmt.Errorf("unable to create backup file: %w", err) } - return f, nil + return f, false, nil } var ( @@ -238,73 +265,128 @@ func hasRelPrefix(rel *v1.Relationship, prefix string) bool 
{ } func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { - f, err := createBackupFile(args[0]) + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + + backupFileName, err := computeBackupFileName(cmd, args) if err != nil { return err } - defer func(e *error) { *e = errors.Join(*e, f.Close()) }(&err) - defer func(e *error) { *e = errors.Join(*e, f.Sync()) }(&err) - - c, err := client.NewClient(cmd) + backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists) if err != nil { - return fmt.Errorf("unable to initialize client: %w", err) + return err } - ctx := cmd.Context() - schemaResp, err := c.ReadSchema(ctx, &v1.ReadSchemaRequest{}) + defer func(e *error) { + *e = errors.Join(*e, backupFile.Sync()) + *e = errors.Join(*e, backupFile.Close()) + }(&err) + + // the goal of this file is to keep the bulk export cursor in case the process is terminated + // and we need to resume from where we left off. OCF does not support in-place record updates. + progressFile, cursor, err := openProgressFile(backupFileName, backupExists) if err != nil { - return fmt.Errorf("error reading schema: %w", addSizeErrInfo(err)) - } else if schemaResp.ReadAt == nil { - return fmt.Errorf("`backup` is not supported on this version of SpiceDB") + return err } - schema := schemaResp.SchemaText - // Remove any invalid relations generated from old, backwards-incompat - // Serverless permission systems. - if cobrautil.MustGetBool(cmd, "rewrite-legacy") { - schema = rewriteLegacy(schema) + var backupCompleted bool + defer func(e *error) { + *e = errors.Join(*e, progressFile.Sync()) + *e = errors.Join(*e, progressFile.Close()) + + if backupCompleted { + if err := os.Remove(progressFile.Name()); err != nil { + log.Warn(). + Str("progress-file", progressFile.Name()). 
+ Msg("failed to remove progress file, consider removing it manually") + } + } + }(&err) + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) } - // Skip any definitions without the provided prefix - prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") - if prefixFilter != "" { - schema, err = filterSchemaDefs(schema, prefixFilter) + var zedToken *v1.ZedToken + var encoder *backupformat.Encoder + if backupExists { + encoder, err = backupformat.NewEncoderForExisting(backupFile) + if err != nil { + return fmt.Errorf("error creating backup file encoder: %w", err) + } + } else { + encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) if err != nil { return err } } - encoder, err := backupformat.NewEncoder(f, schema, schemaResp.ReadAt) - if err != nil { - return fmt.Errorf("error creating backup file encoder: %w", err) - } defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) - relationshipStream, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ - Consistency: &v1.Consistency{ + if zedToken == nil && cursor == nil { + return errors.New("malformed existing backup, consider recreating it") + } + + req := &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } + + // if a cursor is present, zedtoken is not needed (it is already in the cursor) + if zedToken != nil { + req.Consistency = &v1.Consistency{ Requirement: &v1.Consistency_AtExactSnapshot{ - AtExactSnapshot: schemaResp.ReadAt, + AtExactSnapshot: zedToken, }, - }, - }) + } + } + + ctx := cmd.Context() + relationshipStream, err := c.BulkExportRelationships(ctx, req) if err != nil { - return fmt.Errorf("error exporting relationships: %w", addSizeErrInfo(err)) + return fmt.Errorf("error exporting relationships: %w", err) } relationshipReadStart := time.Now() - + tick := time.Tick(5 * time.Second) bar := console.CreateProgressBar("processing backup") - var 
relsEncoded, relsProcessed uint + var relsFilteredOut, relsProcessed uint64 + defer func() { + _ = bar.Finish() + + evt := log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)) + if isCanceled(err) { + evt.Msg("backup canceled - resume by restarting the backup command") + } else if err != nil { + evt.Msg("backup failed") + } else { + evt.Msg("finished backup") + } + }() + for { if err := ctx.Err(); err != nil { + if isCanceled(err) { + return context.Canceled + } + return fmt.Errorf("aborted backup: %w", err) } relsResp, err := relationshipStream.Recv() if err != nil { + if isCanceled(err) { + return context.Canceled + } + if !errors.Is(err, io.EOF) { - return fmt.Errorf("error receiving relationships: %w", addSizeErrInfo(err)) + return fmt.Errorf("error receiving relationships: %w", err) } break } @@ -314,37 +396,169 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { if err := encoder.Append(rel); err != nil { return fmt.Errorf("error storing relationship: %w", err) } - relsEncoded++ - - if relsEncoded%100_000 == 0 && !isatty.IsTerminal(os.Stderr.Fd()) { - log.Trace(). - Uint("encoded", relsEncoded). - Uint("processed", relsProcessed). - Msg("backup progress") - } + } else { + relsFilteredOut++ } + relsProcessed++ if err := bar.Add(1); err != nil { return fmt.Errorf("error incrementing progress bar: %w", err) } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). 
+ Msg("backup progress") + default: + } + } + } + + if err := writeProgress(progressFile, relsResp); err != nil { + return err } } - totalTime := time.Since(relationshipReadStart) - if err := bar.Finish(); err != nil { - return fmt.Errorf("error finalizing progress bar: %w", err) + backupCompleted = true + return nil +} + +// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup +// must be taken. +func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + + schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) + if err != nil { + return nil, nil, fmt.Errorf("error reading schema: %w", err) + } + if schemaResp.ReadAt == nil { + return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") } + schema := schemaResp.SchemaText - log.Info(). - Uint("encoded", relsEncoded). - Uint("processed", relsProcessed). - Uint64("perSecond", perSec(uint64(relsProcessed), totalTime)). - Stringer("duration", totalTime). - Msg("finished backup") + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. 
+ if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return nil, nil, err + } + } + + zedToken := schemaResp.ReadAt + + encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken) + if err != nil { + return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err) + } + + return encoder, zedToken, nil +} + +func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error { + err := progressFile.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + _, err = progressFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token) + if err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } return nil } +// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates +// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error. +// +// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume +// backups in case of failure. 
+func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) {
+	var cursor *v1.Cursor
+	progressFileName := toLockFileName(backupFileName)
+	var progressFile *os.File
+	// the open mode for the progress file depends on whether we are resuming an existing backup (see below)
+	var fileMode int
+	readCursor, err := os.ReadFile(progressFileName)
+	if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) {
+		return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName)
+	} else if backupAlreadyExisted && err == nil {
+		cursor = &v1.Cursor{
+			Token: string(readCursor),
+		}
+
+		// if backup existed and there is a progress marker, the latter should not be truncated to make sure the
+		// cursor stays around in case of a failure before we even start ingesting from bulk export
+		fileMode = os.O_WRONLY | os.O_CREATE
+		log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume")
+	} else {
+		// if a backup did not exist, make sure to truncate the progress file
+		fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
+	}
+
+	progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return progressFile, cursor, nil
+}
+
+func toLockFileName(backupFileName string) string {
+	return backupFileName + ".lock"
+}
+
+// computeBackupFileName computes the backup file name.
+// If no file name is provided, it derives a backup on the current context +func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) { + if len(args) > 0 { + return args[0], nil + } + + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return "", fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + return "", err + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + + return backupFileName, nil +} + +func isCanceled(err error) bool { + if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled { + return true + } + + return errors.Is(err, context.Canceled) +} + func openRestoreFile(filename string) (*os.File, int64, error) { if filename == "" { log.Trace().Str("filename", "(stdin)").Send() @@ -485,7 +699,7 @@ func backupRedactCmdFunc(cmd *cobra.Command, args []string) error { defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) filename := args[0] + ".redacted" - writer, err := createBackupFile(filename) + writer, _, err := createBackupFile(filename, doNotReturnIfExists) if err != nil { return err } @@ -646,8 +860,8 @@ func addSizeErrInfo(err error) error { return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) } - necessaryByteCount, err := strconv.Atoi(matches[1]) - if err != nil { + necessaryByteCount, atoiErr := strconv.Atoi(matches[1]) + if atoiErr != nil { return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) } diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go index 3ac82f30..81be67f2 100644 --- a/internal/cmd/backup_test.go +++ b/internal/cmd/backup_test.go @@ -1,8 +1,9 @@ package cmd import ( - "context" + "encoding/json" "errors" + "fmt" "os" "path/filepath" 
"strings" @@ -15,9 +16,11 @@ import ( "google.golang.org/grpc/status" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/tuple" "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/storage" zedtesting "github.com/authzed/zed/internal/testing" ) @@ -270,16 +273,15 @@ func TestBackupParseSchemaCmdFunc(t *testing.T) { func TestBackupCreateCmdFunc(t *testing.T) { cmd := zedtesting.CreateTestCobraCommandWithFlagValue(t, zedtesting.StringFlag{FlagName: "prefix-filter"}, - zedtesting.BoolFlag{FlagName: "rewrite-legacy"}) - f := filepath.Join(os.TempDir(), uuid.NewString()) - _, err := os.Stat(f) - require.Error(t, err) - defer func() { - _ = os.Remove(f) - }() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + zedtesting.BoolFlag{FlagName: "rewrite-legacy"}, + zedtesting.UintFlag32{FlagName: "page-limit"}, + zedtesting.StringFlag{FlagName: "token"}, + zedtesting.StringFlag{FlagName: "certificate-path"}, + zedtesting.StringFlag{FlagName: "endpoint"}, + zedtesting.BoolFlag{FlagName: "insecure"}, + zedtesting.BoolFlag{FlagName: "no-verify-ca"}) + + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(t, srv.Run(ctx)) @@ -288,9 +290,9 @@ func TestBackupCreateCmdFunc(t *testing.T) { require.NoError(t, err) originalClient := client.NewClient - defer func() { + t.Cleanup(func() { client.NewClient = originalClient - }() + }) client.NewClient = zedtesting.ClientFromConn(conn) @@ -300,32 +302,204 @@ func TestBackupCreateCmdFunc(t *testing.T) { _, err = c.WriteSchema(ctx, &v1.WriteSchemaRequest{Schema: testSchema}) require.NoError(t, err) - testRel := "test/resource:1#reader@test/user:1" - resp, err := c.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{ - Updates: []*v1.RelationshipUpdate{ - { - Operation: v1.RelationshipUpdate_OPERATION_TOUCH, - Relationship: tuple.MustParseV1Rel(testRel), + update := 
&v1.WriteRelationshipsRequest{} + testRel := "test/resource:1#reader@test/user:%d" + expectedRels := make([]string, 0, 100) + for i := range 100 { + relString := fmt.Sprintf(testRel, i) + update.Updates = append(update.Updates, &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: tuple.MustParseV1Rel(relString), + }) + expectedRels = append(expectedRels, relString) + } + resp, err := c.WriteRelationships(ctx, update) + require.NoError(t, err) + + t.Run("successful backup", func(t *testing.T) { + f := filepath.Join(t.TempDir(), uuid.NewString()) + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + + validateBackup(t, f, testSchema, resp.WrittenAt, expectedRels) + // validate progress file is deleted after successful backup + require.NoFileExists(t, toLockFileName(f)) + }) + + t.Run("fails if backup without progress file exists", func(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), uuid.NewString()) + _, err := os.Create(tempFile) + require.NoError(t, err) + + err = backupCreateCmdFunc(cmd, []string{tempFile}) + require.ErrorContains(t, err, "already exists") + }) + + t.Run("derives backup file name from context if not provided", func(t *testing.T) { + ds := client.DefaultStorage + t.Cleanup(func() { + client.DefaultStorage = ds + }) + + cfg := storage.Config{CurrentToken: "my-test"} + cfgBytes, err := json.Marshal(cfg) + require.NoError(t, err) + + testContextPath := filepath.Join(t.TempDir(), "config.json") + err = os.WriteFile(testContextPath, cfgBytes, 0o600) + require.NoError(t, err) + + name := uuid.NewString() + client.DefaultStorage = func() (storage.ConfigStore, storage.SecretStore) { + return &testConfigStore{currentToken: name}, + &testSecretStore{token: storage.Token{Name: name}} + } + err = backupCreateCmdFunc(cmd, nil) + require.NoError(t, err) + + currentPath, err := os.Executable() + require.NoError(t, err) + exPath := filepath.Dir(currentPath) + expectedBackupFile := 
filepath.Join(exPath, name+".zedbackup") + require.FileExists(t, expectedBackupFile) + validateBackup(t, expectedBackupFile, testSchema, resp.WrittenAt, expectedRels) + }) + + t.Run("truncates progress marker if it existed but backup did not", func(t *testing.T) { + streamClient, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: resp.WrittenAt, + }, }, - }, + OptionalLimit: 1, + }) + require.NoError(t, err) + + streamResp, err := streamClient.Recv() + require.NoError(t, err) + _ = streamClient.CloseSend() + + f := filepath.Join(t.TempDir(), uuid.NewString()) + lockFileName := toLockFileName(f) + err = os.WriteFile(lockFileName, []byte(streamResp.AfterResultCursor.Token), 0o600) + require.NoError(t, err) + + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + + // we know it did its work because it imported the 100 relationships regardless of the progress file + validateBackup(t, f, testSchema, resp.WrittenAt, expectedRels) + require.NoFileExists(t, toLockFileName(f)) }) - require.NoError(t, err) - err = backupCreateCmdFunc(cmd, []string{f}) - require.NoError(t, err) + t.Run("resumes backup if marker file exists", func(t *testing.T) { + streamClient, err := c.BulkExportRelationships(ctx, &v1.BulkExportRelationshipsRequest{ + Consistency: &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: resp.WrittenAt, + }, + }, + OptionalLimit: 90, + }) + require.NoError(t, err) + + streamResp, err := streamClient.Recv() + require.NoError(t, err) + _ = streamClient.CloseSend() - d, closer, err := decoderFromArgs(f) + f := filepath.Join(t.TempDir(), uuid.NewString()) + + // do an initial backup to have the OCF metadata in place, it will also import the 100 rels + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + require.FileExists(t, f) + + lockFileName := toLockFileName(f) + err = 
os.WriteFile(lockFileName, []byte(streamResp.AfterResultCursor.Token), 0o600) + require.NoError(t, err) + + // run backup again, this time with an existing backup file and progress file + err = backupCreateCmdFunc(cmd, []string{f}) + require.NoError(t, err) + require.NoFileExists(t, toLockFileName(f)) + + // we know it did its work because we created a progress file at relationship 90, so we will get + // a backup with 100 rels from the original import + the last 10 rels repeated again (110 in total) + validationFunc := func(t *testing.T, expected, received []string) { + require.Len(t, received, 110) + receivedSet := mapz.NewSet(received...) + expectedSet := mapz.NewSet(expected...) + + require.Equal(t, 100, receivedSet.Len()) + require.True(t, receivedSet.Equal(expectedSet)) + + for i, s := range received[100:] { + require.Equal(t, fmt.Sprintf(testRel, i+90), s) + } + } + validateBackupWithFunc(t, f, testSchema, resp.WrittenAt, expectedRels, validationFunc) + }) +} + +type testConfigStore struct { + storage.ConfigStore + currentToken string +} + +func (tcs testConfigStore) Get() (storage.Config, error) { + return storage.Config{CurrentToken: tcs.currentToken}, nil +} + +func (tcs testConfigStore) Exists() (bool, error) { + return true, nil +} + +type testSecretStore struct { + storage.SecretStore + token storage.Token +} + +func (tss testSecretStore) Get() (storage.Secrets, error) { + return storage.Secrets{Tokens: []storage.Token{tss.token}}, nil +} + +func validateBackup(t *testing.T, backupFileName string, schema string, token *v1.ZedToken, expected []string) { + t.Helper() + + f := func(t *testing.T, expected, received []string) { + require.ElementsMatch(t, expected, received) + } + + validateBackupWithFunc(t, backupFileName, schema, token, expected, f) +} + +func validateBackupWithFunc(t *testing.T, backupFileName string, schema string, token *v1.ZedToken, expected []string, + validateRels func(t *testing.T, expected, received []string), +) { + t.Helper() + + 
d, closer, err := decoderFromArgs(backupFileName) require.NoError(t, err) - defer func() { + t.Cleanup(func() { _ = d.Close() _ = closer.Close() - }() + }) - require.Equal(t, testSchema, d.Schema()) - rel, err := d.Next() - require.NoError(t, err) - require.Equal(t, testRel, tuple.MustV1StringRelationship(rel)) - require.Equal(t, resp.WrittenAt.Token, d.ZedToken().Token) + require.Equal(t, schema, d.Schema()) + require.Equal(t, token.Token, d.ZedToken().Token) + var received []string + for { + rel, err := d.Next() + if rel == nil { + break + } + + require.NoError(t, err) + received = append(received, tuple.MustV1StringRelationship(rel)) + } + + validateRels(t, expected, received) } func TestBackupRestoreCmdFunc(t *testing.T) { @@ -340,8 +514,7 @@ func TestBackupRestoreCmdFunc(t *testing.T) { ) backupName := createTestBackup(t, testSchema, testRelationships) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(t, srv.Run(ctx)) diff --git a/internal/cmd/import_test.go b/internal/cmd/import_test.go index f23488e3..071e5027 100644 --- a/internal/cmd/import_test.go +++ b/internal/cmd/import_test.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "path/filepath" "testing" @@ -27,8 +26,7 @@ func TestImportCmdHappyPath(t *testing.T) { f := filepath.Join("import-test", "happy-path-validation-file.yaml") // Set up client - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() srv := zedtesting.NewTestServer(ctx, t) go func() { require.NoError(srv.Run(ctx)) diff --git a/internal/cmd/restorer_test.go b/internal/cmd/restorer_test.go index 255c25cb..fd007376 100644 --- a/internal/cmd/restorer_test.go +++ b/internal/cmd/restorer_test.go @@ -157,7 +157,7 @@ func TestRestorer(t *testing.T) { } r := newRestorer(testSchema, d, c, tt.prefixFilter, tt.batchSize, tt.batchesPerTransaction, tt.conflictStrategy, tt.disableRetryErrors, 
0*time.Second) - err = r.restoreFromDecoder(context.Background()) + err = r.restoreFromDecoder(t.Context()) if expectsError != nil || (expectedConflicts > 0 && tt.conflictStrategy == Fail) { require.ErrorIs(err, expectsError) return diff --git a/internal/cmd/schema_test.go b/internal/cmd/schema_test.go index 2a5908ec..e118f65b 100644 --- a/internal/cmd/schema_test.go +++ b/internal/cmd/schema_test.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -59,7 +58,7 @@ func TestDeterminePrefixForSchema(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - found, err := determinePrefixForSchema(context.Background(), test.specifiedPrefix, nil, &test.existingSchema) + found, err := determinePrefixForSchema(t.Context(), test.specifiedPrefix, nil, &test.existingSchema) require.NoError(t, err) require.Equal(t, test.expectedPrefix, found) }) diff --git a/internal/commands/permission_test.go b/internal/commands/permission_test.go index c3838c87..55021989 100644 --- a/internal/commands/permission_test.go +++ b/internal/commands/permission_test.go @@ -103,7 +103,7 @@ func TestCheckErrorWithInvalidDebugInformation(t *testing.T) { } func TestLookupResourcesCommand(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { diff --git a/internal/commands/relationship_test.go b/internal/commands/relationship_test.go index afcbbebe..70cc25cd 100644 --- a/internal/commands/relationship_test.go +++ b/internal/commands/relationship_test.go @@ -670,7 +670,7 @@ func (m *mockClient) WriteRelationships(_ context.Context, in *v1.WriteRelations } func TestBulkDeleteForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { @@ -720,7 
+720,7 @@ func TestBulkDeleteForcing(t *testing.T) { } func TestBulkDeleteManyForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { @@ -762,7 +762,7 @@ func TestBulkDeleteManyForcing(t *testing.T) { } func TestBulkDeleteNotForcing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() srv := zedtesting.NewTestServer(ctx, t) go func() { diff --git a/internal/grpcutil/batch_test.go b/internal/grpcutil/batch_test.go index 98e48ae4..b89ccc1b 100644 --- a/internal/grpcutil/batch_test.go +++ b/internal/grpcutil/batch_test.go @@ -67,7 +67,7 @@ func TestConcurrentBatchOrdering(t *testing.T) { return nil } - err := ConcurrentBatch(context.Background(), len(tt.items), batchSize, workers, fn) + err := ConcurrentBatch(t.Context(), len(tt.items), batchSize, workers, fn) require.NoError(err) got := make([]batch, len(gotCh)) @@ -133,7 +133,7 @@ func TestConcurrentBatch(t *testing.T) { atomic.AddInt64(&calls, 1) return nil } - err := ConcurrentBatch(context.Background(), len(tt.items), tt.batchSize, tt.workers, fn) + err := ConcurrentBatch(t.Context(), len(tt.items), tt.batchSize, tt.workers, fn) require.NoError(err) require.Equal(tt.wantCalls, int(calls)) diff --git a/internal/storage/secrets.go b/internal/storage/secrets.go index c76620df..3436f6be 100644 --- a/internal/storage/secrets.go +++ b/internal/storage/secrets.go @@ -69,7 +69,7 @@ type SecretStore interface { Put(s Secrets) error } -// Returns an empty token if no token exists. +// GetTokenIfExists returns an empty token if no token exists. 
func GetTokenIfExists(name string, ss SecretStore) (Token, error) { secrets, err := ss.Get() if err != nil { diff --git a/internal/testing/test_helpers.go b/internal/testing/test_helpers.go index 046563a2..d921fdbb 100644 --- a/internal/testing/test_helpers.go +++ b/internal/testing/test_helpers.go @@ -127,6 +127,6 @@ func CreateTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co } } - c.SetContext(context.Background()) + c.SetContext(t.Context()) return &c } diff --git a/pkg/backupformat/encoder.go b/pkg/backupformat/encoder.go index a8f06b83..f591ee72 100644 --- a/pkg/backupformat/encoder.go +++ b/pkg/backupformat/encoder.go @@ -11,6 +11,20 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" ) +func NewEncoderForExisting(w io.Writer) (*Encoder, error) { + avroSchema, err := avroSchemaV1() + if err != nil { + return nil, fmt.Errorf("unable to create avro schema: %w", err) + } + + enc, err := ocf.NewEncoder(avroSchema, w, ocf.WithCodec(ocf.Snappy)) + if err != nil { + return nil, fmt.Errorf("unable to create encoder: %w", err) + } + + return &Encoder{enc}, nil +} + func NewEncoder(w io.Writer, schema string, token *v1.ZedToken) (*Encoder, error) { avroSchema, err := avroSchemaV1() if err != nil {