diff --git a/activity/IActivity.go b/activity/IActivity.go new file mode 100644 index 000000000..245ec2e8e --- /dev/null +++ b/activity/IActivity.go @@ -0,0 +1,21 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package activity + +import "context" + +type Activity interface { + Transaction(ctx context.Context) error + Compensation(ctx context.Context) error +} diff --git a/common/constants/constants.go b/common/constants/constants.go index c0eaa8768..31eb0f2a3 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -101,4 +101,7 @@ const ( // Metadata table names SMT_JOB_TABLE string = "SMT_JOB" SMT_RESOURCE_TABLE string = "SMT_RESOURCE" + + // Reverse Replication + REVERSE_REPLICATION_JOB_TYPE string = "reverse-replication" ) diff --git a/reverserepl/activity/create_smt_job_entry.go b/reverserepl/activity/create_smt_job_entry.go new file mode 100644 index 000000000..9ef0d6082 --- /dev/null +++ b/reverserepl/activity/create_smt_job_entry.go @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package activity + +import ( + "context" + "fmt" + + "cloud.google.com/go/spanner" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/dao" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" +) + +type CreateSmtJobEntryInput struct { + SmtJobId string + JobName string + SpannerProjectId string + InstanceId string + DatabaseId string + JobData string +} + +type CreateSmtJobEntry struct { + Input *CreateSmtJobEntryInput + DAO dao.DAO + SpA spanneraccessor.SpannerAccessor +} + +// This creates a reverse replication entry in the SMT job table. +func (p *CreateSmtJobEntry) Transaction(ctx context.Context) error { + input := p.Input + dialect, err := p.SpA.GetDatabaseDialect(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", input.SpannerProjectId, input.InstanceId, input.DatabaseId)) + if err != nil { + return fmt.Errorf("could not fetch database dialect: %v", err) + } + logger.Log.Debug(fmt.Sprintf("found database dialect: %s", dialect)) + jobData := spanner.NullJSON{Valid: true, Value: input.JobData} + err = p.DAO.InsertSMTJobEntry(ctx, input.SmtJobId, input.JobName, constants.REVERSE_REPLICATION_JOB_TYPE, dialect, input.DatabaseId, jobData) + if err != nil { + return err + } + logger.Log.Debug("Created entry SMT Job entry") + return nil +} + +func (p *CreateSmtJobEntry) Compensation(ctx context.Context) error { + return nil +} diff --git a/reverserepl/activity/create_smt_job_entry_test.go b/reverserepl/activity/create_smt_job_entry_test.go new file mode 100644 index 000000000..e28e43c26 --- /dev/null +++ b/reverserepl/activity/create_smt_job_entry_test.go @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package activity + +import ( + "context" + "fmt" + "os" + "testing" + + "cloud.google.com/go/spanner" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/dao" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +type SpannerAccessorMock struct { + spanneraccessor.SpannerAccessor +} + +var getDatabaseDialectMock func(ctx context.Context, dbURI string) (string, error) + +func (sam *SpannerAccessorMock) GetDatabaseDialect(ctx context.Context, dbURI string) (string, error) { + return getDatabaseDialectMock(ctx, dbURI) +} + +type DAOMock struct { + dao.DAOImpl +} + +var insertSMTJobEntryMock func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error + +func (dao *DAOMock) InsertSMTJobEntry(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error { + return insertSMTJobEntryMock(ctx, jobId, jobName, jobType, dialect, dbName, jobData) +} + +func TestCreateSmtJobEntryTransaction(t *testing.T) { + testCases := []struct { + name string + getDatabaseDialectMock func(ctx context.Context, dbURI string) (string, error) + insertSMTJobEntryMock func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error + expectError bool + }{ + { + name: "No errors", + getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) { + return "", nil + }, + insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error { + return nil + }, + expectError: false, + }, + { + name: "Fetch Dialect error", + getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) { + return "", fmt.Errorf("test error") + }, + insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error { + return nil + }, + expectError: true, + }, + { + name: "Dao error", + getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) { + return "", nil + }, + insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error { + return fmt.Errorf("test error") + }, + expectError: true, + }, + } + ctx := context.Background() + createSmtJobEntry := CreateSmtJobEntry{ + Input: &CreateSmtJobEntryInput{}, + DAO: &DAOMock{}, + SpA: &SpannerAccessorMock{}, + } + for _, tc := range testCases { + getDatabaseDialectMock = tc.getDatabaseDialectMock + insertSMTJobEntryMock = tc.insertSMTJobEntryMock + err := createSmtJobEntry.Transaction(ctx) + assert.Equal(t, tc.expectError, err != nil) + } +}