Skip to content

Commit

Permalink
Move interface to top level folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Deep1998 committed Feb 8, 2024
1 parent 07b047c commit a7f3d5f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
File renamed without changes.
8 changes: 5 additions & 3 deletions reverserepl/activity/create_smt_job_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ type CreateSmtJobEntryInput struct {

type CreateSmtJobEntry struct {
Input *CreateSmtJobEntryInput
DAO dao.DAO
SpA spanneraccessor.SpannerAccessor
}

// This creates an entry in the SMT job table.
// This creates a reverse replication entry in the SMT job table.
func (p *CreateSmtJobEntry) Transaction(ctx context.Context) error {
input := p.Input
dialect, err := spanneraccessor.GetDatabaseDialect(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", input.SpannerProjectId, input.InstanceId, input.DatabaseId))
dialect, err := p.SpA.GetDatabaseDialect(ctx, fmt.Sprintf("projects/%s/instances/%s/databases/%s", input.SpannerProjectId, input.InstanceId, input.DatabaseId))
if err != nil {
return fmt.Errorf("could not fetch database dialect: %v", err)
}
logger.Log.Debug(fmt.Sprintf("found database dialect: %s", dialect))
jobData := spanner.NullJSON{Valid: true, Value: input.JobData}
err = dao.InsertSMTJobEntry(ctx, input.SmtJobId, input.JobName, constants.REVERSE_REPLICATION_JOB_TYPE, dialect, input.DatabaseId, jobData)
err = p.DAO.InsertSMTJobEntry(ctx, input.SmtJobId, input.JobName, constants.REVERSE_REPLICATION_JOB_TYPE, dialect, input.DatabaseId, jobData)
if err != nil {
return err
}
Expand Down
109 changes: 109 additions & 0 deletions reverserepl/activity/create_smt_job_entry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package activity

import (
"context"
"fmt"
"os"
"testing"

"cloud.google.com/go/spanner"
spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner"
"github.com/GoogleCloudPlatform/spanner-migration-tool/dao"
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)

func init() {
logger.Log = zap.NewNop()
}

func TestMain(m *testing.M) {
res := m.Run()
os.Exit(res)
}

type SpannerAccessorMock struct {
spanneraccessor.SpannerAccessor
}

var getDatabaseDialectMock func(ctx context.Context, dbURI string) (string, error)

func (sam *SpannerAccessorMock) GetDatabaseDialect(ctx context.Context, dbURI string) (string, error) {
return getDatabaseDialectMock(ctx, dbURI)
}

type DAOMock struct {
dao.DAOImpl
}

var insertSMTJobEntryMock func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error

func (dao *DAOMock) InsertSMTJobEntry(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error {
return insertSMTJobEntryMock(ctx, jobId, jobName, jobType, dialect, dbName, jobData)
}

func TestCreateSmtJobEntryTransaction(t *testing.T) {
testCases := []struct {
name string
getDatabaseDialectMock func(ctx context.Context, dbURI string) (string, error)
insertSMTJobEntryMock func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error
expectError bool
}{
{
name: "No errors",
getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) {
return "", nil
},
insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error {
return nil
},
expectError: false,
},
{
name: "Fetch Dialect error",
getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) {
return "", fmt.Errorf("test error")
},
insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error {
return nil
},
expectError: true,
},
{
name: "Dao error",
getDatabaseDialectMock: func(ctx context.Context, dbURI string) (string, error) {
return "", nil
},
insertSMTJobEntryMock: func(ctx context.Context, jobId string, jobName string, jobType string, dialect string, dbName string, jobData spanner.NullJSON) error {
return fmt.Errorf("test error")
},
expectError: true,
},
}
ctx := context.Background()
createSmtJobEntry := CreateSmtJobEntry{
Input: &CreateSmtJobEntryInput{},
DAO: &DAOMock{},
SpA: &SpannerAccessorMock{},
}
for _, tc := range testCases {
getDatabaseDialectMock = tc.getDatabaseDialectMock
insertSMTJobEntryMock = tc.insertSMTJobEntryMock
err := createSmtJobEntry.Transaction(ctx)
assert.Equal(t, tc.expectError, err != nil)
}
}

0 comments on commit a7f3d5f

Please sign in to comment.