diff --git a/source/logrepl/snapshot_test.go b/source/logrepl/snapshot_test.go index a47951fe..a817a239 100644 --- a/source/logrepl/snapshot_test.go +++ b/source/logrepl/snapshot_test.go @@ -24,37 +24,64 @@ import ( "github.com/conduitio/conduit-connector-postgres/test" sdk "github.com/conduitio/conduit-connector-sdk" "github.com/jackc/pgx/v4/pgxpool" - "github.com/matryer/is" ) -func TestLifecycle(t *testing.T) { +var ( + columns = []string{"id", "key", "column1", "column2", "column3"} + key = "key" +) + +func TestAtomicSnapshot(t *testing.T) { is := is.New(t) + pool := test.ConnectPool(context.Background(), t, test.RegularConnString) ctx := context.Background() - - pool := test.ConnectPool(ctx, t, test.RegularConnString) table := test.SetupTestTable(ctx, t, pool) - name := createTestSnapshot(t, pool) - - conn, err := pool.Acquire(ctx) - is.NoErr(err) - s, err := NewSnapshotIterator(context.Background(), conn.Conn(), SnapshotConfig{ + name := createTestSnapshot(ctx, t, pool) + s := createTestSnapshotIterator(ctx, t, pool, SnapshotConfig{ SnapshotName: name, Table: table, - Columns: []string{"id", "key", "column1", "column2", "column3"}, - KeyColumn: "key", + Columns: columns, + KeyColumn: key, }) + t.Cleanup(func() { is.NoErr(s.Teardown(ctx)) }) + + // add a record to our table after snapshot started + insertQuery := fmt.Sprintf(`INSERT INTO %s (id, column1, column2, column3) + VALUES (5, 'bizz', 456, false)`, table) + _, err := pool.Exec(ctx, insertQuery) is.NoErr(err) - t.Cleanup(conn.Release) + // assert record does not appear in snapshot + for i := 0; i < 5; i++ { + r, err := s.Next(ctx) + if err != nil { + is.True(errors.Is(err, ErrSnapshotComplete)) + is.Equal(r, sdk.Record{}) + } + } +} + +func TestSnapshotInterrupted(t *testing.T) { + is := is.New(t) + pool := test.ConnectPool(context.Background(), t, test.RegularConnString) + ctx := context.Background() + table := test.SetupTestTable(ctx, t, pool) + name := createTestSnapshot(ctx, t, pool) + s := createTestSnapshotIterator(ctx, t, pool, SnapshotConfig{ + SnapshotName: name, + Table: table, + Columns: columns, + KeyColumn: key, + }) now := time.Now() + rec, err := s.Next(ctx) is.NoErr(err) is.True(rec.CreatedAt.After(now)) is.Equal(rec.Metadata["action"], "snapshot") rec.CreatedAt = time.Time{} // reset time for comparison - is.Equal(rec, sdk.Record{ Position: sdk.Position(fmt.Sprintf("%s:0", table)), Key: sdk.StructuredData{ @@ -74,15 +101,38 @@ func TestLifecycle(t *testing.T) { is.True(errors.Is(s.Teardown(ctx), ErrSnapshotInterrupt)) } +func TestFullIteration(t *testing.T) { + is := is.New(t) + ctx := context.Background() + pool := test.ConnectPool(ctx, t, test.RegularConnString) + table := test.SetupTestTable(ctx, t, pool) + name := createTestSnapshot(ctx, t, pool) + s := createTestSnapshotIterator(ctx, t, pool, SnapshotConfig{ + SnapshotName: name, + Table: table, + Columns: columns, + KeyColumn: key, + }) + + for i := 0; i < 4; i++ { + rec, err := s.Next(ctx) + is.Equal(rec.Position, sdk.Position(fmt.Sprintf("%s:%d", table, i))) + is.NoErr(err) + } + + r, err := s.Next(ctx) + is.Equal(r, sdk.Record{}) + is.True(errors.Is(err, ErrSnapshotComplete)) + is.NoErr(s.Teardown(ctx)) +} + // createTestSnapshot starts a transaction that stays open while a snapshot test // runs. Otherwise, Postgres deletes the snapshot as soon as the transaction // commits or rolls back, and our snapshot iterator won't find a snapshot with // the specified name. // https://www.postgresql.org/docs/current/sql-set-transaction.html -func createTestSnapshot(t *testing.T, pool *pgxpool.Pool) string { - ctx := context.Background() +func createTestSnapshot(ctx context.Context, t *testing.T, pool *pgxpool.Pool) string { is := is.New(t) - conn, err := pool.Acquire(ctx) is.NoErr(err) @@ -103,3 +153,21 @@ func createTestSnapshot(t *testing.T, pool *pgxpool.Pool) string { return name } + +// creates a snapshot iterator for testing that hands its connection's cleanup. +func createTestSnapshotIterator(ctx context.Context, t *testing.T, + pool *pgxpool.Pool, cfg SnapshotConfig) *SnapshotIterator { + is := is.New(t) + + conn, err := pool.Acquire(ctx) + is.NoErr(err) + s, err := NewSnapshotIterator(context.Background(), conn.Conn(), SnapshotConfig{ + SnapshotName: cfg.SnapshotName, + Table: cfg.Table, + Columns: cfg.Columns, + KeyColumn: cfg.KeyColumn, + }) + is.NoErr(err) + t.Cleanup(conn.Release) + return s +} diff --git a/test/helper.go b/test/helper.go index b29b0863..7bb27928 100644 --- a/test/helper.go +++ b/test/helper.go @@ -100,7 +100,9 @@ func SetupTestTable(ctx context.Context, t *testing.T, conn Querier) string { } func RandomIdentifier(t *testing.T) string { - return fmt.Sprintf("conduit_%v_%d", strings.ToLower(t.Name()), time.Now().UnixMicro()%1000) + return fmt.Sprintf("conduit_%v_%d", + strings.ReplaceAll(strings.ToLower(t.Name()), "/", "_"), + time.Now().UnixMicro()%1000) } func IsPgError(is *is.I, err error, wantCode string) {